File size: 1,716 Bytes
8496edd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from .base_agent import BaseAgent
from prompt.template import PROBLEM_ANALYSIS_PROMPT, PROBLEM_ANALYSIS_CRITIQUE_PROMPT, PROBLEM_ANALYSIS_IMPROVEMENT_PROMPT


class ProblemAnalysis(BaseAgent):
    """Agent that drafts a problem analysis and refines it through critique passes.

    Every step renders one of the ``PROBLEM_ANALYSIS_*`` prompt templates,
    strips leading/trailing whitespace from the rendered text, and hands it to
    ``self.llm.generate``. ``self.llm`` is not assigned in this class; it is
    presumably set by ``BaseAgent.__init__`` (TODO confirm against BaseAgent).
    """

    def __init__(self, llm):
        # All setup is delegated to BaseAgent; this subclass adds no state.
        super().__init__(llm)

    def analysis_actor(self, modeling_problem: str, user_prompt: str = ''):
        """Request an initial analysis of ``modeling_problem``.

        Args:
            modeling_problem: Problem statement substituted into
                ``PROBLEM_ANALYSIS_PROMPT``.
            user_prompt: Optional extra instructions for the template;
                empty string by default.

        Returns:
            Whatever ``self.llm.generate`` returns for the rendered prompt.
        """
        rendered = PROBLEM_ANALYSIS_PROMPT.format(
            modeling_problem=modeling_problem,
            user_prompt=user_prompt,
        )
        return self.llm.generate(rendered.strip())

    def analysis_critic(self, modeling_problem: str, problem_analysis: str):
        """Request a critique of an existing analysis.

        Renders ``PROBLEM_ANALYSIS_CRITIQUE_PROMPT`` with the problem and the
        analysis under review. Note that ``user_prompt`` is not passed to this
        template.

        Args:
            modeling_problem: Problem statement the analysis refers to.
            problem_analysis: The analysis text to be critiqued.

        Returns:
            Whatever ``self.llm.generate`` returns for the rendered prompt.
        """
        rendered = PROBLEM_ANALYSIS_CRITIQUE_PROMPT.format(
            modeling_problem=modeling_problem,
            problem_analysis=problem_analysis,
        )
        return self.llm.generate(rendered.strip())

    def analysis_improvement(self, modeling_problem: str, problem_analysis: str, problem_analysis_critique: str, user_prompt: str = ''):
        """Request a revised analysis given the current draft and its critique.

        Args:
            modeling_problem: Problem statement the analysis refers to.
            problem_analysis: The current analysis draft.
            problem_analysis_critique: Critique of that draft, typically the
                output of ``analysis_critic``.
            user_prompt: Optional extra instructions for the template;
                empty string by default.

        Returns:
            Whatever ``self.llm.generate`` returns for the rendered prompt.
        """
        rendered = PROBLEM_ANALYSIS_IMPROVEMENT_PROMPT.format(
            modeling_problem=modeling_problem,
            problem_analysis=problem_analysis,
            problem_analysis_critique=problem_analysis_critique,
            user_prompt=user_prompt,
        )
        return self.llm.generate(rendered.strip())

    def analysis(self, modeling_problem: str, round: int = 3, user_prompt: str = ''):
        """Draft an analysis, then run ``round`` critique-and-improve passes.

        Args:
            modeling_problem: Problem statement to analyse.
            round: Number of refinement passes. The name shadows the builtin
                ``round`` inside this method; it is kept for caller
                compatibility. Zero or negative values skip refinement, so the
                initial draft is returned as-is.
            user_prompt: Optional extra instructions, forwarded to the
                initial-draft and improvement prompts (not to the critique).

        Returns:
            The analysis produced by the last pass, or the initial draft when
            no passes run.
        """
        current = self.analysis_actor(modeling_problem, user_prompt)
        for pass_index in range(round):
            # One progress line on stdout per refinement pass, numbered from 1.
            print(f'Problem Analysis Round {pass_index + 1}')
            feedback = self.analysis_critic(modeling_problem, current)
            current = self.analysis_improvement(modeling_problem, current, feedback, user_prompt)
        return current