169 lines
5.5 KiB
Python
169 lines
5.5 KiB
Python
"""Генератор расписаний с использованием GigaChat."""
|
|
import json
|
|
from typing import List, Optional
|
|
|
|
from models.gigachat_types import GigaChatMessage
|
|
from models.schedule import Schedule, Task
|
|
from prompts.schedule_prompts import SCHEDULE_GENERATION_PROMPT
|
|
|
|
from agents.gigachat_client import GigaChatClient
|
|
|
|
|
|
class ScheduleGenerator:
|
|
"""Генератор расписаний для детей с РАС."""
|
|
|
|
def __init__(self, gigachat: GigaChatClient):
|
|
self.gigachat = gigachat
|
|
|
|
async def generate(
|
|
self,
|
|
child_age: int,
|
|
preferences: List[str],
|
|
date: str,
|
|
existing_tasks: Optional[List[str]] = None,
|
|
model: str = "GigaChat-2-Pro",
|
|
) -> Schedule:
|
|
"""
|
|
Сгенерировать расписание.
|
|
|
|
Args:
|
|
child_age: Возраст ребенка
|
|
preferences: Предпочтения ребенка
|
|
date: Дата расписания
|
|
existing_tasks: Существующие задания для учета
|
|
model: Модель GigaChat
|
|
|
|
Returns:
|
|
Объект расписания
|
|
"""
|
|
preferences_str = ", ".join(preferences) if preferences else "не указаны"
|
|
|
|
prompt = SCHEDULE_GENERATION_PROMPT.format(
|
|
age=child_age,
|
|
preferences=preferences_str,
|
|
date=date,
|
|
)
|
|
|
|
if existing_tasks:
|
|
prompt += f"\n\nУчти существующие задания: {', '.join(existing_tasks)}"
|
|
|
|
# Используем более высокую температуру для разнообразия
|
|
response_text = await self.gigachat.chat(
|
|
message=prompt,
|
|
model=model,
|
|
temperature=0.8,
|
|
max_tokens=3000,
|
|
)
|
|
|
|
# Парсим JSON из ответа
|
|
schedule_data = self._parse_json_response(response_text)
|
|
|
|
# Создаем объект Schedule
|
|
tasks = [
|
|
Task(
|
|
title=task_data["title"],
|
|
description=task_data.get("description"),
|
|
duration_minutes=task_data["duration_minutes"],
|
|
category=task_data.get("category", "обучение"),
|
|
)
|
|
for task_data in schedule_data.get("tasks", [])
|
|
]
|
|
|
|
return Schedule(
|
|
title=schedule_data.get("title", f"Расписание на {date}"),
|
|
date=date,
|
|
tasks=tasks,
|
|
)
|
|
|
|
async def update(
|
|
self,
|
|
existing_schedule: Schedule,
|
|
user_request: str,
|
|
model: str = "GigaChat-2-Pro",
|
|
) -> Schedule:
|
|
"""
|
|
Обновить существующее расписание.
|
|
|
|
Args:
|
|
existing_schedule: Текущее расписание
|
|
user_request: Запрос на изменение
|
|
model: Модель GigaChat
|
|
|
|
Returns:
|
|
Обновленное расписание
|
|
"""
|
|
from prompts.schedule_prompts import SCHEDULE_UPDATE_PROMPT
|
|
|
|
schedule_json = existing_schedule.model_dump_json()
|
|
|
|
prompt = SCHEDULE_UPDATE_PROMPT.format(
|
|
existing_schedule=schedule_json,
|
|
user_request=user_request,
|
|
)
|
|
|
|
response_text = await self.gigachat.chat(
|
|
message=prompt,
|
|
model=model,
|
|
temperature=0.7,
|
|
max_tokens=3000,
|
|
)
|
|
|
|
schedule_data = self._parse_json_response(response_text)
|
|
|
|
tasks = [
|
|
Task(
|
|
title=task_data["title"],
|
|
description=task_data.get("description"),
|
|
duration_minutes=task_data["duration_minutes"],
|
|
category=task_data.get("category", "обучение"),
|
|
)
|
|
for task_data in schedule_data.get("tasks", [])
|
|
]
|
|
|
|
return Schedule(
|
|
id=existing_schedule.id,
|
|
title=schedule_data.get("title", existing_schedule.title),
|
|
date=existing_schedule.date,
|
|
tasks=tasks,
|
|
user_id=existing_schedule.user_id,
|
|
)
|
|
|
|
def _parse_json_response(self, response_text: str) -> dict:
|
|
"""
|
|
Извлечь JSON из ответа модели.
|
|
|
|
Args:
|
|
response_text: Текст ответа
|
|
|
|
Returns:
|
|
Распарсенный JSON
|
|
"""
|
|
# Пытаемся найти JSON в ответе
|
|
response_text = response_text.strip()
|
|
|
|
# Удаляем markdown код блоки если есть
|
|
if response_text.startswith("```json"):
|
|
response_text = response_text[7:]
|
|
if response_text.startswith("```"):
|
|
response_text = response_text[3:]
|
|
if response_text.endswith("```"):
|
|
response_text = response_text[:-3]
|
|
|
|
response_text = response_text.strip()
|
|
|
|
try:
|
|
return json.loads(response_text)
|
|
except json.JSONDecodeError:
|
|
# Если не удалось распарсить, пытаемся найти JSON объект в тексте
|
|
start_idx = response_text.find("{")
|
|
end_idx = response_text.rfind("}") + 1
|
|
|
|
if start_idx >= 0 and end_idx > start_idx:
|
|
try:
|
|
return json.loads(response_text[start_idx:end_idx])
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
raise ValueError(f"Не удалось распарсить JSON из ответа: {response_text[:200]}")
|
|
|