Manager.php 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866
  1. <?php
  2. declare(strict_types=1);
  3. /**
  4. * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
  5. * SPDX-License-Identifier: AGPL-3.0-or-later
  6. */
  7. namespace OC\TaskProcessing;
  8. use OC\AppFramework\Bootstrap\Coordinator;
  9. use OC\Files\SimpleFS\SimpleFile;
  10. use OC\TaskProcessing\Db\TaskMapper;
  11. use OCP\AppFramework\Db\DoesNotExistException;
  12. use OCP\AppFramework\Db\MultipleObjectsReturnedException;
  13. use OCP\BackgroundJob\IJobList;
  14. use OCP\EventDispatcher\IEventDispatcher;
  15. use OCP\Files\AppData\IAppDataFactory;
  16. use OCP\Files\File;
  17. use OCP\Files\GenericFileException;
  18. use OCP\Files\IAppData;
  19. use OCP\Files\IRootFolder;
  20. use OCP\Files\NotPermittedException;
  21. use OCP\Files\SimpleFS\ISimpleFile;
  22. use OCP\IL10N;
  23. use OCP\IServerContainer;
  24. use OCP\L10N\IFactory;
  25. use OCP\Lock\LockedException;
  26. use OCP\SpeechToText\ISpeechToTextProvider;
  27. use OCP\SpeechToText\ISpeechToTextProviderWithId;
  28. use OCP\TaskProcessing\EShapeType;
  29. use OCP\TaskProcessing\Events\TaskFailedEvent;
  30. use OCP\TaskProcessing\Events\TaskSuccessfulEvent;
  31. use OCP\TaskProcessing\Exception\NotFoundException;
  32. use OCP\TaskProcessing\Exception\ProcessingException;
  33. use OCP\TaskProcessing\Exception\UnauthorizedException;
  34. use OCP\TaskProcessing\Exception\ValidationException;
  35. use OCP\TaskProcessing\IManager;
  36. use OCP\TaskProcessing\IProvider;
  37. use OCP\TaskProcessing\ISynchronousProvider;
  38. use OCP\TaskProcessing\ITaskType;
  39. use OCP\TaskProcessing\ShapeDescriptor;
  40. use OCP\TaskProcessing\Task;
  41. use OCP\TaskProcessing\TaskTypes\AudioToText;
  42. use OCP\TaskProcessing\TaskTypes\TextToImage;
  43. use OCP\TaskProcessing\TaskTypes\TextToText;
  44. use OCP\TaskProcessing\TaskTypes\TextToTextHeadline;
  45. use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
  46. use OCP\TaskProcessing\TaskTypes\TextToTextTopics;
  47. use Psr\Log\LoggerInterface;
  48. class Manager implements IManager {
  49. public const LEGACY_PREFIX_TEXTPROCESSING = 'legacy:TextProcessing:';
  50. public const LEGACY_PREFIX_TEXTTOIMAGE = 'legacy:TextToImage:';
  51. public const LEGACY_PREFIX_SPEECHTOTEXT = 'legacy:SpeechToText:';
  52. /** @var list<IProvider>|null */
  53. private ?array $providers = null;
  54. /** @var array<string,array{name: string, description: string, inputShape: array<string, ShapeDescriptor>, optionalInputShape: array<string, ShapeDescriptor>, outputShape: array<string, ShapeDescriptor>, optionalOutputShape: array<string, ShapeDescriptor>}>|null */
  55. private ?array $availableTaskTypes = null;
  56. private IAppData $appData;
  57. public function __construct(
  58. private Coordinator $coordinator,
  59. private IServerContainer $serverContainer,
  60. private LoggerInterface $logger,
  61. private TaskMapper $taskMapper,
  62. private IJobList $jobList,
  63. private IEventDispatcher $dispatcher,
  64. IAppDataFactory $appDataFactory,
  65. private IRootFolder $rootFolder,
  66. private \OCP\TextProcessing\IManager $textProcessingManager,
  67. private \OCP\TextToImage\IManager $textToImageManager,
  68. private \OCP\SpeechToText\ISpeechToTextManager $speechToTextManager,
  69. private \OCP\Share\IManager $shareManager,
  70. ) {
  71. $this->appData = $appDataFactory->get('core');
  72. }
  73. /**
  74. * @return IProvider[]
  75. */
  76. private function _getTextProcessingProviders(): array {
  77. $oldProviders = $this->textProcessingManager->getProviders();
  78. $newProviders = [];
  79. foreach ($oldProviders as $oldProvider) {
  80. $provider = new class($oldProvider) implements IProvider, ISynchronousProvider {
  81. private \OCP\TextProcessing\IProvider $provider;
  82. public function __construct(\OCP\TextProcessing\IProvider $provider) {
  83. $this->provider = $provider;
  84. }
  85. public function getId(): string {
  86. if ($this->provider instanceof \OCP\TextProcessing\IProviderWithId) {
  87. return $this->provider->getId();
  88. }
  89. return Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->provider::class;
  90. }
  91. public function getName(): string {
  92. return $this->provider->getName();
  93. }
  94. public function getTaskTypeId(): string {
  95. return match ($this->provider->getTaskType()) {
  96. \OCP\TextProcessing\FreePromptTaskType::class => TextToText::ID,
  97. \OCP\TextProcessing\HeadlineTaskType::class => TextToTextHeadline::ID,
  98. \OCP\TextProcessing\TopicsTaskType::class => TextToTextTopics::ID,
  99. \OCP\TextProcessing\SummaryTaskType::class => TextToTextSummary::ID,
  100. default => Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->provider->getTaskType(),
  101. };
  102. }
  103. public function getExpectedRuntime(): int {
  104. if ($this->provider instanceof \OCP\TextProcessing\IProviderWithExpectedRuntime) {
  105. return $this->provider->getExpectedRuntime();
  106. }
  107. return 60;
  108. }
  109. public function getOptionalInputShape(): array {
  110. return [];
  111. }
  112. public function getOptionalOutputShape(): array {
  113. return [];
  114. }
  115. public function process(?string $userId, array $input, callable $reportProgress): array {
  116. if ($this->provider instanceof \OCP\TextProcessing\IProviderWithUserId) {
  117. $this->provider->setUserId($userId);
  118. }
  119. try {
  120. return ['output' => $this->provider->process($input['input'])];
  121. } catch(\RuntimeException $e) {
  122. throw new ProcessingException($e->getMessage(), 0, $e);
  123. }
  124. }
  125. };
  126. $newProviders[$provider->getId()] = $provider;
  127. }
  128. return $newProviders;
  129. }
  130. /**
  131. * @return ITaskType[]
  132. */
  133. private function _getTextProcessingTaskTypes(): array {
  134. $oldProviders = $this->textProcessingManager->getProviders();
  135. $newTaskTypes = [];
  136. foreach ($oldProviders as $oldProvider) {
  137. // These are already implemented in the TaskProcessing realm
  138. if (in_array($oldProvider->getTaskType(), [
  139. \OCP\TextProcessing\FreePromptTaskType::class,
  140. \OCP\TextProcessing\HeadlineTaskType::class,
  141. \OCP\TextProcessing\TopicsTaskType::class,
  142. \OCP\TextProcessing\SummaryTaskType::class
  143. ], true)) {
  144. continue;
  145. }
  146. $taskType = new class($oldProvider->getTaskType()) implements ITaskType {
  147. private string $oldTaskTypeClass;
  148. private \OCP\TextProcessing\ITaskType $oldTaskType;
  149. private IL10N $l;
  150. public function __construct(string $oldTaskTypeClass) {
  151. $this->oldTaskTypeClass = $oldTaskTypeClass;
  152. $this->oldTaskType = \OCP\Server::get($oldTaskTypeClass);
  153. $this->l = \OCP\Server::get(IFactory::class)->get('core');
  154. }
  155. public function getId(): string {
  156. return Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->oldTaskTypeClass;
  157. }
  158. public function getName(): string {
  159. return $this->oldTaskType->getName();
  160. }
  161. public function getDescription(): string {
  162. return $this->oldTaskType->getDescription();
  163. }
  164. public function getInputShape(): array {
  165. return ['input' => new ShapeDescriptor($this->l->t('Input text'), $this->l->t('The input text'), EShapeType::Text)];
  166. }
  167. public function getOutputShape(): array {
  168. return ['output' => new ShapeDescriptor($this->l->t('Input text'), $this->l->t('The input text'), EShapeType::Text)];
  169. }
  170. };
  171. $newTaskTypes[$taskType->getId()] = $taskType;
  172. }
  173. return $newTaskTypes;
  174. }
  175. /**
  176. * @return IProvider[]
  177. */
  178. private function _getTextToImageProviders(): array {
  179. $oldProviders = $this->textToImageManager->getProviders();
  180. $newProviders = [];
  181. foreach ($oldProviders as $oldProvider) {
  182. $newProvider = new class($oldProvider, $this->appData) implements IProvider, ISynchronousProvider {
  183. private \OCP\TextToImage\IProvider $provider;
  184. private IAppData $appData;
  185. public function __construct(\OCP\TextToImage\IProvider $provider, IAppData $appData) {
  186. $this->provider = $provider;
  187. $this->appData = $appData;
  188. }
  189. public function getId(): string {
  190. return Manager::LEGACY_PREFIX_TEXTTOIMAGE . $this->provider->getId();
  191. }
  192. public function getName(): string {
  193. return $this->provider->getName();
  194. }
  195. public function getTaskTypeId(): string {
  196. return TextToImage::ID;
  197. }
  198. public function getExpectedRuntime(): int {
  199. return $this->provider->getExpectedRuntime();
  200. }
  201. public function getOptionalInputShape(): array {
  202. return [];
  203. }
  204. public function getOptionalOutputShape(): array {
  205. return [];
  206. }
  207. public function process(?string $userId, array $input, callable $reportProgress): array {
  208. try {
  209. $folder = $this->appData->getFolder('text2image');
  210. } catch(\OCP\Files\NotFoundException) {
  211. $folder = $this->appData->newFolder('text2image');
  212. }
  213. $resources = [];
  214. $files = [];
  215. for ($i = 0; $i < $input['numberOfImages']; $i++) {
  216. $file = $folder->newFile(time() . '-' . rand(1, 100000) . '-' . $i);
  217. $files[] = $file;
  218. $resource = $file->write();
  219. if ($resource !== false && $resource !== true && is_resource($resource)) {
  220. $resources[] = $resource;
  221. } else {
  222. throw new ProcessingException('Text2Image generation using provider "' . $this->getName() . '" failed: Couldn\'t open file to write.');
  223. }
  224. }
  225. if ($this->provider instanceof \OCP\TextToImage\IProviderWithUserId) {
  226. $this->provider->setUserId($userId);
  227. }
  228. try {
  229. $this->provider->generate($input['input'], $resources);
  230. } catch (\RuntimeException $e) {
  231. throw new ProcessingException($e->getMessage(), 0, $e);
  232. }
  233. for ($i = 0; $i < $input['numberOfImages']; $i++) {
  234. if (is_resource($resources[$i])) {
  235. // If $resource hasn't been closed yet, we'll do that here
  236. fclose($resources[$i]);
  237. }
  238. }
  239. return ['images' => array_map(fn (ISimpleFile $file) => $file->getContent(), $files)];
  240. }
  241. };
  242. $newProviders[$newProvider->getId()] = $newProvider;
  243. }
  244. return $newProviders;
  245. }
  246. /**
  247. * @return IProvider[]
  248. */
  249. private function _getSpeechToTextProviders(): array {
  250. $oldProviders = $this->speechToTextManager->getProviders();
  251. $newProviders = [];
  252. foreach ($oldProviders as $oldProvider) {
  253. $newProvider = new class($oldProvider, $this->rootFolder, $this->appData) implements IProvider, ISynchronousProvider {
  254. private ISpeechToTextProvider $provider;
  255. private IAppData $appData;
  256. private IRootFolder $rootFolder;
  257. public function __construct(ISpeechToTextProvider $provider, IRootFolder $rootFolder, IAppData $appData) {
  258. $this->provider = $provider;
  259. $this->rootFolder = $rootFolder;
  260. $this->appData = $appData;
  261. }
  262. public function getId(): string {
  263. if ($this->provider instanceof ISpeechToTextProviderWithId) {
  264. return Manager::LEGACY_PREFIX_SPEECHTOTEXT . $this->provider->getId();
  265. }
  266. return Manager::LEGACY_PREFIX_SPEECHTOTEXT . $this->provider::class;
  267. }
  268. public function getName(): string {
  269. return $this->provider->getName();
  270. }
  271. public function getTaskTypeId(): string {
  272. return AudioToText::ID;
  273. }
  274. public function getExpectedRuntime(): int {
  275. return 60;
  276. }
  277. public function getOptionalInputShape(): array {
  278. return [];
  279. }
  280. public function getOptionalOutputShape(): array {
  281. return [];
  282. }
  283. public function process(?string $userId, array $input, callable $reportProgress): array {
  284. try {
  285. $result = $this->provider->transcribeFile($input['input']);
  286. } catch (\RuntimeException $e) {
  287. throw new ProcessingException($e->getMessage(), 0, $e);
  288. }
  289. return ['output' => $result];
  290. }
  291. };
  292. $newProviders[$newProvider->getId()] = $newProvider;
  293. }
  294. return $newProviders;
  295. }
  296. /**
  297. * @return IProvider[]
  298. */
  299. private function _getProviders(): array {
  300. $context = $this->coordinator->getRegistrationContext();
  301. if ($context === null) {
  302. return [];
  303. }
  304. $providers = [];
  305. foreach ($context->getTaskProcessingProviders() as $providerServiceRegistration) {
  306. $class = $providerServiceRegistration->getService();
  307. try {
  308. /** @var IProvider $provider */
  309. $provider = $this->serverContainer->get($class);
  310. if (isset($providers[$provider->getId()])) {
  311. $this->logger->warning('Task processing provider ' . $class . ' is using ID ' . $provider->getId() . ' which is already used by ' . $providers[$provider->getId()]::class);
  312. }
  313. $providers[$provider->getId()] = $provider;
  314. } catch (\Throwable $e) {
  315. $this->logger->error('Failed to load task processing provider ' . $class, [
  316. 'exception' => $e,
  317. ]);
  318. }
  319. }
  320. $providers += $this->_getTextProcessingProviders() + $this->_getTextToImageProviders() + $this->_getSpeechToTextProviders();
  321. return $providers;
  322. }
  323. /**
  324. * @return ITaskType[]
  325. */
  326. private function _getTaskTypes(): array {
  327. $context = $this->coordinator->getRegistrationContext();
  328. if ($context === null) {
  329. return [];
  330. }
  331. // Default task types
  332. $taskTypes = [
  333. \OCP\TaskProcessing\TaskTypes\TextToText::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToText::class),
  334. \OCP\TaskProcessing\TaskTypes\TextToTextTopics::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextTopics::class),
  335. \OCP\TaskProcessing\TaskTypes\TextToTextHeadline::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextHeadline::class),
  336. \OCP\TaskProcessing\TaskTypes\TextToTextSummary::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextSummary::class),
  337. \OCP\TaskProcessing\TaskTypes\TextToImage::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToImage::class),
  338. \OCP\TaskProcessing\TaskTypes\AudioToText::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\AudioToText::class),
  339. ];
  340. foreach ($context->getTaskProcessingTaskTypes() as $providerServiceRegistration) {
  341. $class = $providerServiceRegistration->getService();
  342. try {
  343. /** @var ITaskType $provider */
  344. $taskType = $this->serverContainer->get($class);
  345. if (isset($taskTypes[$taskType->getId()])) {
  346. $this->logger->warning('Task processing task type ' . $class . ' is using ID ' . $taskType->getId() . ' which is already used by ' . $taskTypes[$taskType->getId()]::class);
  347. }
  348. $taskTypes[$taskType->getId()] = $taskType;
  349. } catch (\Throwable $e) {
  350. $this->logger->error('Failed to load task processing task type ' . $class, [
  351. 'exception' => $e,
  352. ]);
  353. }
  354. }
  355. $taskTypes += $this->_getTextProcessingTaskTypes();
  356. return $taskTypes;
  357. }
  358. /**
  359. * @param string $taskType
  360. * @return IProvider
  361. * @throws \OCP\TaskProcessing\Exception\Exception
  362. */
  363. private function _getPreferredProvider(string $taskType) {
  364. $providers = $this->getProviders();
  365. foreach ($providers as $provider) {
  366. if ($provider->getTaskTypeId() === $taskType) {
  367. return $provider;
  368. }
  369. }
  370. throw new \OCP\TaskProcessing\Exception\Exception('No matching provider found');
  371. }
  372. /**
  373. * @param ShapeDescriptor[] $spec
  374. * @param array $io
  375. * @return void
  376. * @throws ValidationException
  377. */
  378. private function validateInput(array $spec, array $io, bool $optional = false): void {
  379. foreach ($spec as $key => $descriptor) {
  380. $type = $descriptor->getShapeType();
  381. if (!isset($io[$key])) {
  382. if ($optional) {
  383. continue;
  384. }
  385. throw new ValidationException('Missing key: "' . $key . '"');
  386. }
  387. try {
  388. $type->validateInput($io[$key]);
  389. } catch (ValidationException $e) {
  390. throw new ValidationException('Failed to validate input key "' . $key . '": ' . $e->getMessage());
  391. }
  392. }
  393. }
  394. /**
  395. * @param ShapeDescriptor[] $spec
  396. * @param array $io
  397. * @param bool $optional
  398. * @return void
  399. * @throws ValidationException
  400. */
  401. private function validateOutput(array $spec, array $io, bool $optional = false): void {
  402. foreach ($spec as $key => $descriptor) {
  403. $type = $descriptor->getShapeType();
  404. if (!isset($io[$key])) {
  405. if ($optional) {
  406. continue;
  407. }
  408. throw new ValidationException('Missing key: "' . $key . '"');
  409. }
  410. try {
  411. $type->validateOutput($io[$key]);
  412. } catch (ValidationException $e) {
  413. throw new ValidationException('Failed to validate output key "' . $key . '": ' . $e->getMessage());
  414. }
  415. }
  416. }
  417. /**
  418. * @param array<array-key, T> $array The array to filter
  419. * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep
  420. * @return array<array-key, T>
  421. * @psalm-template T
  422. */
  423. private function removeSuperfluousArrayKeys(array $array, ...$specs): array {
  424. $keys = array_unique(array_reduce($specs, fn ($carry, $spec) => $carry + array_keys($spec), []));
  425. $values = array_map(fn (string $key) => $array[$key], $keys);
  426. return array_combine($keys, $values);
  427. }
  428. public function hasProviders(): bool {
  429. return count($this->getProviders()) !== 0;
  430. }
  431. public function getProviders(): array {
  432. if ($this->providers === null) {
  433. $this->providers = $this->_getProviders();
  434. }
  435. return $this->providers;
  436. }
  437. public function getAvailableTaskTypes(): array {
  438. if ($this->availableTaskTypes === null) {
  439. $taskTypes = $this->_getTaskTypes();
  440. $providers = $this->getProviders();
  441. $availableTaskTypes = [];
  442. foreach ($providers as $provider) {
  443. if (!isset($taskTypes[$provider->getTaskTypeId()])) {
  444. continue;
  445. }
  446. $taskType = $taskTypes[$provider->getTaskTypeId()];
  447. $availableTaskTypes[$provider->getTaskTypeId()] = [
  448. 'name' => $taskType->getName(),
  449. 'description' => $taskType->getDescription(),
  450. 'inputShape' => $taskType->getInputShape(),
  451. 'optionalInputShape' => $provider->getOptionalInputShape(),
  452. 'outputShape' => $taskType->getOutputShape(),
  453. 'optionalOutputShape' => $provider->getOptionalOutputShape(),
  454. ];
  455. }
  456. $this->availableTaskTypes = $availableTaskTypes;
  457. }
  458. return $this->availableTaskTypes;
  459. }
  460. public function canHandleTask(Task $task): bool {
  461. return isset($this->getAvailableTaskTypes()[$task->getTaskTypeId()]);
  462. }
  463. public function scheduleTask(Task $task): void {
  464. if (!$this->canHandleTask($task)) {
  465. throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskTypeId());
  466. }
  467. $taskTypes = $this->getAvailableTaskTypes();
  468. $inputShape = $taskTypes[$task->getTaskTypeId()]['inputShape'];
  469. $optionalInputShape = $taskTypes[$task->getTaskTypeId()]['optionalInputShape'];
  470. // validate input
  471. $this->validateInput($inputShape, $task->getInput());
  472. $this->validateInput($optionalInputShape, $task->getInput(), true);
  473. // authenticate access to mentioned files
  474. $ids = [];
  475. foreach ($inputShape + $optionalInputShape as $key => $descriptor) {
  476. if (in_array(EShapeType::getScalarType($descriptor->getShapeType()), [EShapeType::File, EShapeType::Image, EShapeType::Audio, EShapeType::Video], true)) {
  477. /** @var list<int>|int $inputSlot */
  478. $inputSlot = $task->getInput()[$key];
  479. if (is_array($inputSlot)) {
  480. $ids += $inputSlot;
  481. } else {
  482. $ids[] = $inputSlot;
  483. }
  484. }
  485. }
  486. foreach ($ids as $fileId) {
  487. $node = $this->rootFolder->getFirstNodeById($fileId);
  488. if ($node === null) {
  489. $node = $this->rootFolder->getFirstNodeByIdInPath($fileId, '/' . $this->rootFolder->getAppDataDirectoryName() . '/');
  490. if ($node === null) {
  491. throw new ValidationException('Could not find file ' . $fileId);
  492. }
  493. }
  494. /** @var array{users:array<string,array{node_id:int, node_path: string}>, remote: array<string,array{node_id:int, node_path: string}>, mail: array<string,array{node_id:int, node_path: string}>} $accessList */
  495. $accessList = $this->shareManager->getAccessList($node, true, true);
  496. $userIds = array_map(fn ($id) => strval($id), array_keys($accessList['users']));
  497. if (!in_array($task->getUserId(), $userIds)) {
  498. throw new UnauthorizedException('User ' . $task->getUserId() . ' does not have access to file ' . $fileId);
  499. }
  500. }
  501. // remove superfluous keys and set input
  502. $task->setInput($this->removeSuperfluousArrayKeys($task->getInput(), $inputShape, $optionalInputShape));
  503. $task->setStatus(Task::STATUS_SCHEDULED);
  504. $provider = $this->_getPreferredProvider($task->getTaskTypeId());
  505. // calculate expected completion time
  506. $completionExpectedAt = new \DateTime('now');
  507. $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
  508. $task->setCompletionExpectedAt($completionExpectedAt);
  509. // create a db entity and insert into db table
  510. $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task);
  511. $this->taskMapper->insert($taskEntity);
  512. // make sure the scheduler knows the id
  513. $task->setId($taskEntity->getId());
  514. // schedule synchronous job if the provider is synchronous
  515. if ($provider instanceof ISynchronousProvider) {
  516. $this->jobList->add(SynchronousBackgroundJob::class, null);
  517. }
  518. }
  519. public function deleteTask(Task $task): void {
  520. $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task);
  521. $this->taskMapper->delete($taskEntity);
  522. }
  523. public function getTask(int $id): Task {
  524. try {
  525. $taskEntity = $this->taskMapper->find($id);
  526. return $taskEntity->toPublicTask();
  527. } catch (DoesNotExistException $e) {
  528. throw new NotFoundException('Couldn\'t find task with id ' . $id, 0, $e);
  529. } catch (MultipleObjectsReturnedException|\OCP\DB\Exception $e) {
  530. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
  531. } catch (\JsonException $e) {
  532. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $e);
  533. }
  534. }
  535. public function cancelTask(int $id): void {
  536. $task = $this->getTask($id);
  537. if ($task->getStatus() !== Task::STATUS_SCHEDULED && $task->getStatus() !== Task::STATUS_RUNNING) {
  538. return;
  539. }
  540. $task->setStatus(Task::STATUS_CANCELLED);
  541. $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task);
  542. try {
  543. $this->taskMapper->update($taskEntity);
  544. } catch (\OCP\DB\Exception $e) {
  545. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
  546. }
  547. }
  548. public function setTaskProgress(int $id, float $progress): bool {
  549. // TODO: Not sure if we should rather catch the exceptions of getTask here and fail silently
  550. $task = $this->getTask($id);
  551. if ($task->getStatus() === Task::STATUS_CANCELLED) {
  552. return false;
  553. }
  554. $task->setStatus(Task::STATUS_RUNNING);
  555. $task->setProgress($progress);
  556. $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task);
  557. try {
  558. $this->taskMapper->update($taskEntity);
  559. } catch (\OCP\DB\Exception $e) {
  560. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
  561. }
  562. return true;
  563. }
  564. public function setTaskResult(int $id, ?string $error, ?array $result): void {
  565. // TODO: Not sure if we should rather catch the exceptions of getTask here and fail silently
  566. $task = $this->getTask($id);
  567. if ($task->getStatus() === Task::STATUS_CANCELLED) {
  568. $this->logger->info('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' finished but was cancelled in the mean time. Moving on without storing result.');
  569. return;
  570. }
  571. if ($error !== null) {
  572. $task->setStatus(Task::STATUS_FAILED);
  573. $task->setErrorMessage($error);
  574. $this->logger->warning('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' failed with the following message: ' . $error);
  575. } elseif ($result !== null) {
  576. $taskTypes = $this->getAvailableTaskTypes();
  577. $outputShape = $taskTypes[$task->getTaskTypeId()]['outputShape'];
  578. $optionalOutputShape = $taskTypes[$task->getTaskTypeId()]['optionalOutputShape'];
  579. try {
  580. // validate output
  581. $this->validateOutput($outputShape, $result);
  582. $this->validateOutput($optionalOutputShape, $result, true);
  583. $output = $this->removeSuperfluousArrayKeys($result, $outputShape, $optionalOutputShape);
  584. // extract raw data and put it in files, replace it with file ids
  585. $output = $this->encapsulateOutputFileData($output, $outputShape, $optionalOutputShape);
  586. $task->setOutput($output);
  587. $task->setProgress(1);
  588. $task->setStatus(Task::STATUS_SUCCESSFUL);
  589. } catch (ValidationException $e) {
  590. $task->setProgress(1);
  591. $task->setStatus(Task::STATUS_FAILED);
  592. $error = 'The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec';
  593. $task->setErrorMessage($error);
  594. $this->logger->error($error, ['exception' => $e]);
  595. } catch (NotPermittedException $e) {
  596. $task->setProgress(1);
  597. $task->setStatus(Task::STATUS_FAILED);
  598. $error = 'The task was processed successfully but storing the output in a file failed';
  599. $task->setErrorMessage($error);
  600. $this->logger->error($error, ['exception' => $e]);
  601. }
  602. }
  603. $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task);
  604. try {
  605. $this->taskMapper->update($taskEntity);
  606. } catch (\OCP\DB\Exception $e) {
  607. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
  608. }
  609. if ($task->getStatus() === Task::STATUS_SUCCESSFUL) {
  610. $event = new TaskSuccessfulEvent($task);
  611. } else {
  612. $event = new TaskFailedEvent($task, $error);
  613. }
  614. $this->dispatcher->dispatchTyped($event);
  615. }
  616. public function getNextScheduledTask(?string $taskTypeId = null): Task {
  617. try {
  618. $taskEntity = $this->taskMapper->findOldestScheduledByType($taskTypeId);
  619. return $taskEntity->toPublicTask();
  620. } catch (DoesNotExistException $e) {
  621. throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', 0, $e);
  622. } catch (\OCP\DB\Exception $e) {
  623. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
  624. } catch (\JsonException $e) {
  625. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $e);
  626. }
  627. }
  628. /**
  629. * Takes task input or output data and replaces fileIds with base64 data
  630. *
  631. * @param string|null $userId
  632. * @param array<array-key, list<numeric|string>|numeric|string> $input
  633. * @param ShapeDescriptor[] ...$specs the specs
  634. * @return array<array-key, list<File|numeric|string>|numeric|string|File>
  635. * @throws GenericFileException
  636. * @throws LockedException
  637. * @throws NotPermittedException
  638. * @throws ValidationException
  639. */
  640. public function fillInputFileData(?string $userId, array $input, ...$specs): array {
  641. if ($userId !== null) {
  642. \OC_Util::setupFS($userId);
  643. }
  644. $newInputOutput = [];
  645. $spec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []);
  646. foreach($spec as $key => $descriptor) {
  647. $type = $descriptor->getShapeType();
  648. if (!isset($input[$key])) {
  649. continue;
  650. }
  651. if (!in_array(EShapeType::getScalarType($type), [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File], true)) {
  652. $newInputOutput[$key] = $input[$key];
  653. continue;
  654. }
  655. if ($type->value < 10) {
  656. $node = $this->rootFolder->getFirstNodeById((int)$input[$key]);
  657. if ($node === null) {
  658. $node = $this->rootFolder->getFirstNodeByIdInPath((int)$input[$key], '/' . $this->rootFolder->getAppDataDirectoryName() . '/');
  659. if (!$node instanceof File) {
  660. throw new ValidationException('File id given for key "' . $key . '" is not a file');
  661. }
  662. } elseif (!$node instanceof File) {
  663. throw new ValidationException('File id given for key "' . $key . '" is not a file');
  664. }
  665. // TODO: Validate if userId has access to this file
  666. $newInputOutput[$key] = $node;
  667. } else {
  668. $newInputOutput[$key] = [];
  669. foreach ($input[$key] as $item) {
  670. $node = $this->rootFolder->getFirstNodeById((int)$item);
  671. if ($node === null) {
  672. $node = $this->rootFolder->getFirstNodeByIdInPath((int)$item, '/' . $this->rootFolder->getAppDataDirectoryName() . '/');
  673. if (!$node instanceof File) {
  674. throw new ValidationException('File id given for key "' . $key . '" is not a file');
  675. }
  676. } elseif (!$node instanceof File) {
  677. throw new ValidationException('File id given for key "' . $key . '" is not a file');
  678. }
  679. // TODO: Validate if userId has access to this file
  680. $newInputOutput[$key][] = $node;
  681. }
  682. }
  683. }
  684. return $newInputOutput;
  685. }
  686. public function getUserTask(int $id, ?string $userId): Task {
  687. try {
  688. $taskEntity = $this->taskMapper->findByIdAndUser($id, $userId);
  689. return $taskEntity->toPublicTask();
  690. } catch (DoesNotExistException $e) {
  691. throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', 0, $e);
  692. } catch (MultipleObjectsReturnedException|\OCP\DB\Exception $e) {
  693. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
  694. } catch (\JsonException $e) {
  695. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $e);
  696. }
  697. }
  698. public function getUserTasks(?string $userId, ?string $taskTypeId = null, ?string $customId = null): array {
  699. try {
  700. $taskEntities = $this->taskMapper->findByUserAndTaskType($userId, $taskTypeId, $customId);
  701. return array_map(fn ($taskEntity): Task => $taskEntity->toPublicTask(), $taskEntities);
  702. } catch (\OCP\DB\Exception $e) {
  703. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the tasks', 0, $e);
  704. } catch (\JsonException $e) {
  705. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the tasks', 0, $e);
  706. }
  707. }
  708. public function getUserTasksByApp(?string $userId, string $appId, ?string $customId = null): array {
  709. try {
  710. $taskEntities = $this->taskMapper->findUserTasksByApp($userId, $appId, $customId);
  711. return array_map(fn ($taskEntity): Task => $taskEntity->toPublicTask(), $taskEntities);
  712. } catch (\OCP\DB\Exception $e) {
  713. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding a task', 0, $e);
  714. } catch (\JsonException $e) {
  715. throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding a task', 0, $e);
  716. }
  717. }
  718. /**
  719. *Takes task input or output and replaces base64 data with file ids
  720. *
  721. * @param array $output
  722. * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep
  723. * @return array
  724. * @throws NotPermittedException
  725. */
  726. public function encapsulateOutputFileData(array $output, ...$specs): array {
  727. $newOutput = [];
  728. try {
  729. $folder = $this->appData->getFolder('TaskProcessing');
  730. } catch (\OCP\Files\NotFoundException) {
  731. $folder = $this->appData->newFolder('TaskProcessing');
  732. }
  733. $spec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []);
  734. foreach($spec as $key => $descriptor) {
  735. $type = $descriptor->getShapeType();
  736. if (!isset($output[$key])) {
  737. continue;
  738. }
  739. if (!in_array(EShapeType::getScalarType($type), [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File], true)) {
  740. $newOutput[$key] = $output[$key];
  741. continue;
  742. }
  743. if ($type->value < 10) {
  744. /** @var SimpleFile $file */
  745. $file = $folder->newFile((string) rand(0, 10000000), $output[$key]);
  746. $newOutput[$key] = $file->getId(); // polymorphic call to SimpleFile
  747. } else {
  748. $newOutput = [];
  749. foreach ($output[$key] as $item) {
  750. /** @var SimpleFile $file */
  751. $file = $folder->newFile((string) rand(0, 10000000), $item);
  752. $newOutput[$key][] = $file->getId();
  753. }
  754. }
  755. }
  756. return $newOutput;
  757. }
  758. /**
  759. * @param Task $task
  760. * @return array<array-key, list<numeric|string|File>|numeric|string|File>
  761. * @throws GenericFileException
  762. * @throws LockedException
  763. * @throws NotPermittedException
  764. * @throws ValidationException
  765. */
  766. public function prepareInputData(Task $task): array {
  767. $taskTypes = $this->getAvailableTaskTypes();
  768. $inputShape = $taskTypes[$task->getTaskTypeId()]['inputShape'];
  769. $optionalInputShape = $taskTypes[$task->getTaskTypeId()]['optionalInputShape'];
  770. $input = $task->getInput();
  771. // validate input, again for good measure (should have been validated in scheduleTask)
  772. $this->validateInput($inputShape, $input);
  773. $this->validateInput($optionalInputShape, $input, true);
  774. $input = $this->removeSuperfluousArrayKeys($input, $inputShape, $optionalInputShape);
  775. $input = $this->fillInputFileData($task->getUserId(), $input, $inputShape, $optionalInputShape);
  776. return $input;
  777. }
  778. }