Manager.php 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. <?php
  2. declare(strict_types=1);
  3. /**
  4. * SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
  5. * SPDX-License-Identifier: AGPL-3.0-or-later
  6. */
  7. namespace OC\TextProcessing;
  8. use OC\AppFramework\Bootstrap\Coordinator;
  9. use OC\TextProcessing\Db\Task as DbTask;
  10. use OC\TextProcessing\Db\TaskMapper;
  11. use OCP\AppFramework\Db\DoesNotExistException;
  12. use OCP\AppFramework\Db\MultipleObjectsReturnedException;
  13. use OCP\BackgroundJob\IJobList;
  14. use OCP\Common\Exception\NotFoundException;
  15. use OCP\DB\Exception;
  16. use OCP\IConfig;
  17. use OCP\IServerContainer;
  18. use OCP\PreConditionNotMetException;
  19. use OCP\TextProcessing\Exception\TaskFailureException;
  20. use OCP\TextProcessing\IManager;
  21. use OCP\TextProcessing\IProvider;
  22. use OCP\TextProcessing\IProviderWithExpectedRuntime;
  23. use OCP\TextProcessing\IProviderWithId;
  24. use OCP\TextProcessing\Task;
  25. use OCP\TextProcessing\Task as OCPTask;
  26. use Psr\Log\LoggerInterface;
  27. use RuntimeException;
  28. use Throwable;
  29. class Manager implements IManager {
  30. /** @var ?IProvider[] */
  31. private ?array $providers = null;
  32. public function __construct(
  33. private IServerContainer $serverContainer,
  34. private Coordinator $coordinator,
  35. private LoggerInterface $logger,
  36. private IJobList $jobList,
  37. private TaskMapper $taskMapper,
  38. private IConfig $config,
  39. ) {
  40. }
  41. public function getProviders(): array {
  42. $context = $this->coordinator->getRegistrationContext();
  43. if ($context === null) {
  44. return [];
  45. }
  46. if ($this->providers !== null) {
  47. return $this->providers;
  48. }
  49. $this->providers = [];
  50. foreach ($context->getTextProcessingProviders() as $providerServiceRegistration) {
  51. $class = $providerServiceRegistration->getService();
  52. try {
  53. $this->providers[$class] = $this->serverContainer->get($class);
  54. } catch (Throwable $e) {
  55. $this->logger->error('Failed to load Text processing provider ' . $class, [
  56. 'exception' => $e,
  57. ]);
  58. }
  59. }
  60. return $this->providers;
  61. }
  62. public function hasProviders(): bool {
  63. $context = $this->coordinator->getRegistrationContext();
  64. if ($context === null) {
  65. return false;
  66. }
  67. return count($context->getTextProcessingProviders()) > 0;
  68. }
  69. /**
  70. * @inheritDoc
  71. */
  72. public function getAvailableTaskTypes(): array {
  73. $tasks = [];
  74. foreach ($this->getProviders() as $provider) {
  75. $tasks[$provider->getTaskType()] = true;
  76. }
  77. return array_keys($tasks);
  78. }
  79. public function canHandleTask(OCPTask $task): bool {
  80. return in_array($task->getType(), $this->getAvailableTaskTypes());
  81. }
  82. /**
  83. * @inheritDoc
  84. */
  85. public function runTask(OCPTask $task): string {
  86. if (!$this->canHandleTask($task)) {
  87. throw new PreConditionNotMetException('No text processing provider is installed that can handle this task');
  88. }
  89. $providers = $this->getPreferredProviders($task);
  90. foreach ($providers as $provider) {
  91. try {
  92. $task->setStatus(OCPTask::STATUS_RUNNING);
  93. if ($provider instanceof IProviderWithExpectedRuntime) {
  94. $completionExpectedAt = new \DateTime('now');
  95. $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
  96. $task->setCompletionExpectedAt($completionExpectedAt);
  97. }
  98. if ($task->getId() === null) {
  99. $taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task));
  100. $task->setId($taskEntity->getId());
  101. } else {
  102. $this->taskMapper->update(DbTask::fromPublicTask($task));
  103. }
  104. $output = $task->visitProvider($provider);
  105. $task->setOutput($output);
  106. $task->setStatus(OCPTask::STATUS_SUCCESSFUL);
  107. $this->taskMapper->update(DbTask::fromPublicTask($task));
  108. return $output;
  109. } catch (\Throwable $e) {
  110. $this->logger->error('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]);
  111. $task->setStatus(OCPTask::STATUS_FAILED);
  112. $this->taskMapper->update(DbTask::fromPublicTask($task));
  113. throw new TaskFailureException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage(), 0, $e);
  114. }
  115. }
  116. $task->setStatus(OCPTask::STATUS_FAILED);
  117. $this->taskMapper->update(DbTask::fromPublicTask($task));
  118. throw new TaskFailureException('Could not run task');
  119. }
  120. /**
  121. * @inheritDoc
  122. */
  123. public function scheduleTask(OCPTask $task): void {
  124. if (!$this->canHandleTask($task)) {
  125. throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
  126. }
  127. $task->setStatus(OCPTask::STATUS_SCHEDULED);
  128. $providers = $this->getPreferredProviders($task);
  129. if (count($providers) === 0) {
  130. throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
  131. }
  132. [$provider,] = $providers;
  133. if ($provider instanceof IProviderWithExpectedRuntime) {
  134. $completionExpectedAt = new \DateTime('now');
  135. $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
  136. $task->setCompletionExpectedAt($completionExpectedAt);
  137. }
  138. $taskEntity = DbTask::fromPublicTask($task);
  139. $this->taskMapper->insert($taskEntity);
  140. $task->setId($taskEntity->getId());
  141. $this->jobList->add(TaskBackgroundJob::class, [
  142. 'taskId' => $task->getId()
  143. ]);
  144. }
  145. /**
  146. * @inheritDoc
  147. */
  148. public function runOrScheduleTask(OCPTask $task): bool {
  149. if (!$this->canHandleTask($task)) {
  150. throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
  151. }
  152. [$provider,] = $this->getPreferredProviders($task);
  153. $maxExecutionTime = (int) ini_get('max_execution_time');
  154. // Offload the task to a background job if the expected runtime of the likely provider is longer than 80% of our max execution time
  155. // or if the provider doesn't provide a getExpectedRuntime() method
  156. if (!$provider instanceof IProviderWithExpectedRuntime || $provider->getExpectedRuntime() > $maxExecutionTime * 0.8) {
  157. $this->scheduleTask($task);
  158. return false;
  159. }
  160. $this->runTask($task);
  161. return true;
  162. }
  163. /**
  164. * @inheritDoc
  165. */
  166. public function deleteTask(Task $task): void {
  167. $taskEntity = DbTask::fromPublicTask($task);
  168. $this->taskMapper->delete($taskEntity);
  169. $this->jobList->remove(TaskBackgroundJob::class, [
  170. 'taskId' => $task->getId()
  171. ]);
  172. }
  173. /**
  174. * Get a task from its id
  175. *
  176. * @param int $id The id of the task
  177. * @return OCPTask
  178. * @throws RuntimeException If the query failed
  179. * @throws NotFoundException If the task could not be found
  180. */
  181. public function getTask(int $id): OCPTask {
  182. try {
  183. $taskEntity = $this->taskMapper->find($id);
  184. return $taskEntity->toPublicTask();
  185. } catch (DoesNotExistException $e) {
  186. throw new NotFoundException('Could not find task with the provided id');
  187. } catch (MultipleObjectsReturnedException $e) {
  188. throw new RuntimeException('Could not uniquely identify task with given id', 0, $e);
  189. } catch (Exception $e) {
  190. throw new RuntimeException('Failure while trying to find task by id: ' . $e->getMessage(), 0, $e);
  191. }
  192. }
  193. /**
  194. * Get a task from its user id and task id
  195. * If userId is null, this can only get a task that was scheduled anonymously
  196. *
  197. * @param int $id The id of the task
  198. * @param string|null $userId The user id that scheduled the task
  199. * @return OCPTask
  200. * @throws RuntimeException If the query failed
  201. * @throws NotFoundException If the task could not be found
  202. */
  203. public function getUserTask(int $id, ?string $userId): OCPTask {
  204. try {
  205. $taskEntity = $this->taskMapper->findByIdAndUser($id, $userId);
  206. return $taskEntity->toPublicTask();
  207. } catch (DoesNotExistException $e) {
  208. throw new NotFoundException('Could not find task with the provided id and user id');
  209. } catch (MultipleObjectsReturnedException $e) {
  210. throw new RuntimeException('Could not uniquely identify task with given id and user id', 0, $e);
  211. } catch (Exception $e) {
  212. throw new RuntimeException('Failure while trying to find task by id and user id: ' . $e->getMessage(), 0, $e);
  213. }
  214. }
  215. /**
  216. * Get a list of tasks scheduled by a specific user for a specific app
  217. * and optionally with a specific identifier.
  218. * This cannot be used to get anonymously scheduled tasks
  219. *
  220. * @param string $userId
  221. * @param string $appId
  222. * @param string|null $identifier
  223. * @return array
  224. */
  225. public function getUserTasksByApp(string $userId, string $appId, ?string $identifier = null): array {
  226. try {
  227. $taskEntities = $this->taskMapper->findUserTasksByApp($userId, $appId, $identifier);
  228. return array_map(static function (DbTask $taskEntity) {
  229. return $taskEntity->toPublicTask();
  230. }, $taskEntities);
  231. } catch (Exception $e) {
  232. throw new RuntimeException('Failure while trying to find tasks by appId and identifier: ' . $e->getMessage(), 0, $e);
  233. }
  234. }
  235. /**
  236. * @param OCPTask $task
  237. * @return IProvider[]
  238. */
  239. public function getPreferredProviders(OCPTask $task): array {
  240. $providers = $this->getProviders();
  241. $json = $this->config->getAppValue('core', 'ai.textprocessing_provider_preferences', '');
  242. if ($json !== '') {
  243. $preferences = json_decode($json, true);
  244. if (isset($preferences[$task->getType()])) {
  245. // If a preference for this task type is set, move the preferred provider to the start
  246. $provider = current(array_values(array_filter($providers, function ($provider) use ($preferences, $task) {
  247. if ($provider instanceof IProviderWithId) {
  248. return $provider->getId() === $preferences[$task->getType()];
  249. }
  250. return $provider::class === $preferences[$task->getType()];
  251. })));
  252. if ($provider !== false) {
  253. $providers = array_filter($providers, fn ($p) => $p !== $provider);
  254. array_unshift($providers, $provider);
  255. }
  256. }
  257. }
  258. return array_values(array_filter($providers, fn (IProvider $provider) => $task->canUseProvider($provider)));
  259. }
  260. }