Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/core/src/agent/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ export class Agent<

this.onTaskStartTip = this.opts.onTaskStartTip;

this.insight = new Insight(async (action: InsightAction) => {
return this.getUIContext(action);
this.insight = new Insight(async () => {
return this.getUIContext();
});

// Process cache configuration
Expand Down
154 changes: 46 additions & 108 deletions packages/core/src/agent/tasks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,20 @@ import {
plan,
uiTarsPlanning,
} from '@/ai-model';
import { Executor } from '@/ai-model/action-executor';
import type { TMultimodalPrompt, TUserPrompt } from '@/ai-model/common';
import type { AbstractInterface } from '@/device';
import { Executor } from '@/executor';
import type Insight from '@/insight';
import type {
AIUsageInfo,
DetailedLocateParam,
DumpSubscriber,
ElementCacheFeature,
ExecutionRecorderItem,
ExecutionTaskActionApply,
ExecutionTaskApply,
ExecutionTaskHitBy,
ExecutionTaskInsightLocateApply,
ExecutionTaskInsightQueryApply,
ExecutionTaskPlanning,
ExecutionTaskPlanningApply,
ExecutionTaskProgressOptions,
ExecutorContext,
Expand Down Expand Up @@ -109,43 +107,6 @@ export class TaskExecutor {
this.conversationHistory = new ConversationHistory();
}

private async recordScreenshot(timing: ExecutionRecorderItem['timing']) {
const base64 = await this.interface.screenshotBase64();
const item: ExecutionRecorderItem = {
type: 'screenshot',
ts: Date.now(),
screenshot: base64,
timing,
};
return item;
}

private prependExecutorWithScreenshot(
taskApply: ExecutionTaskApply,
appendAfterExecution = false,
): ExecutionTaskApply {
const taskWithScreenshot: ExecutionTaskApply = {
...taskApply,
executor: async (param, context, ...args) => {
const recorder: ExecutionRecorderItem[] = [];
const { task } = context;
// set the recorder before executor in case of error
task.recorder = recorder;
const shot = await this.recordScreenshot(`before ${task.type}`);
recorder.push(shot);

const result = await taskApply.executor(param, context, ...args);

if (appendAfterExecution) {
const shot2 = await this.recordScreenshot('after Action');
recorder.push(shot2);
}
return result;
},
};
return taskWithScreenshot;
}

public async convertPlanToExecutable(
plans: PlanningAction[],
modelConfig: IModelConfig,
Expand Down Expand Up @@ -201,19 +162,9 @@ export class TaskExecutor {
}
};
this.insight.onceDumpUpdatedFn = dumpCollector;
const shotTime = Date.now();

// Get context through contextRetrieverFn which handles frozen context
const uiContext = await this.insight.contextRetrieverFn('locate');
task.uiContext = uiContext;

const recordItem: ExecutionRecorderItem = {
type: 'screenshot',
ts: shotTime,
screenshot: uiContext.screenshotBase64,
timing: 'before Insight',
};
task.recorder = [recordItem];
const { uiContext } = taskContext;
assert(uiContext, 'uiContext is required for Insight task');

// try matching xpath
const elementFromXpath =
Expand Down Expand Up @@ -470,17 +421,17 @@ export class TaskExecutor {
subType: planType,
thought: plan.thought,
param: plan.param,
executor: async (param, context) => {
executor: async (param, taskContext) => {
debug(
'executing action',
planType,
param,
`context.element.center: ${context.element?.center}`,
`taskContext.element.center: ${taskContext.element?.center}`,
);

// Get context for actionSpace operations to ensure size info is available
const uiContext = await this.insight.contextRetrieverFn('locate');
context.task.uiContext = uiContext;
const uiContext = taskContext.uiContext;
assert(uiContext, 'uiContext is required for Action task');

requiredLocateFields.forEach((field) => {
assert(
Expand Down Expand Up @@ -523,7 +474,7 @@ export class TaskExecutor {

debug('calling action', action.name);
const actionFn = action.call.bind(this.interface);
await actionFn(param, context);
await actionFn(param, taskContext);
debug('called action', action.name);

try {
Expand Down Expand Up @@ -554,45 +505,28 @@ export class TaskExecutor {
}
}

const wrappedTasks = tasks.map(
(task: ExecutionTaskApply, index: number) => {
if (task.type === 'Action') {
return this.prependExecutorWithScreenshot(
task,
index === tasks.length - 1,
);
}
return task;
},
);

return {
tasks: wrappedTasks,
tasks,
};
}

private async setupPlanningContext(executorContext: ExecutorContext) {
const shotTime = Date.now();
const uiContext = await this.insight.contextRetrieverFn('locate');
const recordItem: ExecutionRecorderItem = {
type: 'screenshot',
ts: shotTime,
screenshot: uiContext.screenshotBase64,
timing: 'before Planning',
};

executorContext.task.recorder = [recordItem];
(executorContext.task as ExecutionTaskPlanning).uiContext = uiContext;
const uiContext = executorContext.uiContext;
assert(uiContext, 'uiContext is required for Planning task');

return {
uiContext,
};
}

async loadYamlFlowAsPlanning(userInstruction: string, yamlString: string) {
const taskExecutor = new Executor(taskTitleStr('Action', userInstruction), {
onTaskStart: this.onTaskStartCallback,
});
const taskExecutor = new Executor(
taskTitleStr('Action', userInstruction),
() => Promise.resolve(this.insight.contextRetrieverFn()),
{
onTaskStart: this.onTaskStartCallback,
},
);

const task: ExecutionTaskPlanningApply = {
type: 'Planning',
Expand Down Expand Up @@ -741,9 +675,13 @@ export class TaskExecutor {
plans: PlanningAction[],
modelConfig: IModelConfig,
): Promise<ExecutionResult> {
const taskExecutor = new Executor(title, {
onTaskStart: this.onTaskStartCallback,
});
const taskExecutor = new Executor(
title,
() => Promise.resolve(this.insight.contextRetrieverFn()),
{
onTaskStart: this.onTaskStartCallback,
},
);
const { tasks } = await this.convertPlanToExecutable(plans, modelConfig);
await taskExecutor.append(tasks);
const result = await taskExecutor.flush();
Expand Down Expand Up @@ -781,9 +719,13 @@ export class TaskExecutor {
> {
this.conversationHistory.reset();

const taskExecutor = new Executor(taskTitleStr('Action', userPrompt), {
onTaskStart: this.onTaskStartCallback,
});
const taskExecutor = new Executor(
taskTitleStr('Action', userPrompt),
() => Promise.resolve(this.insight.contextRetrieverFn()),
{
onTaskStart: this.onTaskStartCallback,
},
);

let replanCount = 0;
const yamlFlow: MidsceneYamlFlowItem[] = [];
Expand Down Expand Up @@ -891,17 +833,8 @@ export class TaskExecutor {
this.insight.onceDumpUpdatedFn = dumpCollector;

// Get context for query operations
const shotTime = Date.now();
const uiContext = await this.insight.contextRetrieverFn('extract');
task.uiContext = uiContext;

const recordItem: ExecutionRecorderItem = {
type: 'screenshot',
ts: shotTime,
screenshot: uiContext.screenshotBase64,
timing: 'before Extract',
};
task.recorder = [recordItem];
const uiContext = taskContext.uiContext;
assert(uiContext, 'uiContext is required for Query task');

const ifTypeRestricted = type !== 'Query';
let demandInput = demand;
Expand Down Expand Up @@ -965,6 +898,7 @@ export class TaskExecutor {
type,
typeof demand === 'string' ? demand : JSON.stringify(demand),
),
() => Promise.resolve(this.insight.contextRetrieverFn()),
{
onTaskStart: this.onTaskStartCallback,
},
Expand All @@ -978,7 +912,7 @@ export class TaskExecutor {
multimodalPrompt,
);

await taskExecutor.append(this.prependExecutorWithScreenshot(queryTask));
await taskExecutor.append(queryTask);
const result = await taskExecutor.flush();

if (!result) {
Expand Down Expand Up @@ -1012,7 +946,7 @@ export class TaskExecutor {
[errorPlan],
modelConfig,
);
await taskExecutor.append(this.prependExecutorWithScreenshot(tasks[0]));
await taskExecutor.append(tasks[0]);
await taskExecutor.flush();

return {
Expand All @@ -1035,7 +969,7 @@ export class TaskExecutor {
modelConfig,
);

return this.prependExecutorWithScreenshot(sleepTasks[0]);
return sleepTasks[0];
}

async waitFor(
Expand All @@ -1046,9 +980,13 @@ export class TaskExecutor {
const { textPrompt, multimodalPrompt } = parsePrompt(assertion);

const description = `waitFor: ${textPrompt}`;
const taskExecutor = new Executor(taskTitleStr('WaitFor', description), {
onTaskStart: this.onTaskStartCallback,
});
const taskExecutor = new Executor(
taskTitleStr('WaitFor', description),
() => Promise.resolve(this.insight.contextRetrieverFn()),
{
onTaskStart: this.onTaskStartCallback,
},
);
const { timeoutMs, checkIntervalMs } = opt;

assert(assertion, 'No assertion for waitFor');
Expand All @@ -1075,7 +1013,7 @@ export class TaskExecutor {
multimodalPrompt,
);

await taskExecutor.append(this.prependExecutorWithScreenshot(queryTask));
await taskExecutor.append(queryTask);
const result = (await taskExecutor.flush()) as
| {
output: boolean;
Expand Down
Loading