Manager.php 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882
  1. <?php
  2. declare(strict_types=1);
  3. /**
  4. * @copyright Copyright (c) 2024 Marcel Klehr <mklehr@gmx.net>
  5. *
  6. * @author Marcel Klehr <mklehr@gmx.net>
  7. *
  8. * @license GNU AGPL version 3 or any later version
  9. *
  10. * This program is free software: you can redistribute it and/or modify
  11. * it under the terms of the GNU Affero General Public License as
  12. * published by the Free Software Foundation, either version 3 of the
  13. * License, or (at your option) any later version.
  14. *
  15. * This program is distributed in the hope that it will be useful,
  16. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  17. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  18. * GNU Affero General Public License for more details.
  19. *
  20. * You should have received a copy of the GNU Affero General Public License
  21. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  22. */
  23. namespace OC\TaskProcessing;
  24. use OC\AppFramework\Bootstrap\Coordinator;
  25. use OC\Files\SimpleFS\SimpleFile;
  26. use OC\TaskProcessing\Db\TaskMapper;
  27. use OCP\AppFramework\Db\DoesNotExistException;
  28. use OCP\AppFramework\Db\MultipleObjectsReturnedException;
  29. use OCP\BackgroundJob\IJobList;
  30. use OCP\EventDispatcher\IEventDispatcher;
  31. use OCP\Files\AppData\IAppDataFactory;
  32. use OCP\Files\File;
  33. use OCP\Files\GenericFileException;
  34. use OCP\Files\IAppData;
  35. use OCP\Files\IRootFolder;
  36. use OCP\Files\NotPermittedException;
  37. use OCP\Files\SimpleFS\ISimpleFile;
  38. use OCP\IL10N;
  39. use OCP\IServerContainer;
  40. use OCP\L10N\IFactory;
  41. use OCP\Lock\LockedException;
  42. use OCP\SpeechToText\ISpeechToTextProvider;
  43. use OCP\SpeechToText\ISpeechToTextProviderWithId;
  44. use OCP\TaskProcessing\EShapeType;
  45. use OCP\TaskProcessing\Events\TaskFailedEvent;
  46. use OCP\TaskProcessing\Events\TaskSuccessfulEvent;
  47. use OCP\TaskProcessing\Exception\NotFoundException;
  48. use OCP\TaskProcessing\Exception\ProcessingException;
  49. use OCP\TaskProcessing\Exception\UnauthorizedException;
  50. use OCP\TaskProcessing\Exception\ValidationException;
  51. use OCP\TaskProcessing\IManager;
  52. use OCP\TaskProcessing\IProvider;
  53. use OCP\TaskProcessing\ISynchronousProvider;
  54. use OCP\TaskProcessing\ITaskType;
  55. use OCP\TaskProcessing\ShapeDescriptor;
  56. use OCP\TaskProcessing\Task;
  57. use OCP\TaskProcessing\TaskTypes\AudioToText;
  58. use OCP\TaskProcessing\TaskTypes\TextToImage;
  59. use OCP\TaskProcessing\TaskTypes\TextToText;
  60. use OCP\TaskProcessing\TaskTypes\TextToTextHeadline;
  61. use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
  62. use OCP\TaskProcessing\TaskTypes\TextToTextTopics;
  63. use Psr\Log\LoggerInterface;
  64. class Manager implements IManager {
  65. public const LEGACY_PREFIX_TEXTPROCESSING = 'legacy:TextProcessing:';
  66. public const LEGACY_PREFIX_TEXTTOIMAGE = 'legacy:TextToImage:';
  67. public const LEGACY_PREFIX_SPEECHTOTEXT = 'legacy:SpeechToText:';
  68. /** @var list<IProvider>|null */
  69. private ?array $providers = null;
  70. /** @var array<string,array{name: string, description: string, inputShape: array<string, ShapeDescriptor>, optionalInputShape: array<string, ShapeDescriptor>, outputShape: array<string, ShapeDescriptor>, optionalOutputShape: array<string, ShapeDescriptor>}>|null */
  71. private ?array $availableTaskTypes = null;
  72. private IAppData $appData;
  /**
   * Receives the collaborators of the task processing manager.
   *
   * Every parameter declared with a visibility modifier is promoted to a
   * private property. The app data factory is not kept; it is only used to
   * obtain the 'core' app data storage, which is stored in $this->appData.
   */
  public function __construct(
      private Coordinator $coordinator,
      private IServerContainer $serverContainer,
      private LoggerInterface $logger,
      private TaskMapper $taskMapper,
      private IJobList $jobList,
      private IEventDispatcher $dispatcher,
      IAppDataFactory $appDataFactory,
      private IRootFolder $rootFolder,
      private \OCP\TextProcessing\IManager $textProcessingManager,
      private \OCP\TextToImage\IManager $textToImageManager,
      private \OCP\SpeechToText\ISpeechToTextManager $speechToTextManager,
      private \OCP\Share\IManager $shareManager,
  ) {
      $this->appData = $appDataFactory->get('core');
  }
  /**
   * Wraps every provider of the legacy TextProcessing manager in an adapter
   * implementing the TaskProcessing provider interfaces.
   *
   * @return IProvider[] the adapters, keyed by the ID each adapter reports
   */
  private function _getTextProcessingProviders(): array {
      $adapters = [];
      foreach ($this->textProcessingManager->getProviders() as $legacyProvider) {
          $adapter = new class($legacyProvider) implements IProvider, ISynchronousProvider {
              public function __construct(
                  private \OCP\TextProcessing\IProvider $provider,
              ) {
              }

              public function getId(): string {
                  // Legacy providers that expose their own ID keep it unprefixed;
                  // all others are identified by the prefixed class name.
                  return $this->provider instanceof \OCP\TextProcessing\IProviderWithId
                      ? $this->provider->getId()
                      : Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->provider::class;
              }

              public function getName(): string {
                  return $this->provider->getName();
              }

              public function getTaskTypeId(): string {
                  $legacyTaskType = $this->provider->getTaskType();
                  // The four built-in legacy task types map onto their TaskProcessing
                  // counterparts; anything else gets a prefixed legacy ID.
                  return match ($legacyTaskType) {
                      \OCP\TextProcessing\FreePromptTaskType::class => TextToText::ID,
                      \OCP\TextProcessing\HeadlineTaskType::class => TextToTextHeadline::ID,
                      \OCP\TextProcessing\TopicsTaskType::class => TextToTextTopics::ID,
                      \OCP\TextProcessing\SummaryTaskType::class => TextToTextSummary::ID,
                      default => Manager::LEGACY_PREFIX_TEXTPROCESSING . $legacyTaskType,
                  };
              }

              public function getExpectedRuntime(): int {
                  // 60 is the fallback when the legacy provider gives no estimate.
                  return $this->provider instanceof \OCP\TextProcessing\IProviderWithExpectedRuntime
                      ? $this->provider->getExpectedRuntime()
                      : 60;
              }

              public function getOptionalInputShape(): array {
                  return [];
              }

              public function getOptionalOutputShape(): array {
                  return [];
              }

              public function process(?string $userId, array $input, callable $reportProgress): array {
                  if ($this->provider instanceof \OCP\TextProcessing\IProviderWithUserId) {
                      $this->provider->setUserId($userId);
                  }
                  try {
                      $text = $this->provider->process($input['input']);
                  } catch (\RuntimeException $e) {
                      // Translate legacy failures into the TaskProcessing exception type.
                      throw new ProcessingException($e->getMessage(), 0, $e);
                  }
                  return ['output' => $text];
              }
          };
          $adapters[$adapter->getId()] = $adapter;
      }
      return $adapters;
  }
  /**
   * Builds TaskProcessing task types for legacy TextProcessing task types that
   * have no native TaskProcessing equivalent.
   *
   * The free prompt, headline, topics and summary task types are skipped
   * because they are already implemented in the TaskProcessing realm.
   *
   * @return ITaskType[] the wrapped task types, keyed by their prefixed legacy ID
   */
  private function _getTextProcessingTaskTypes(): array {
      $nativeEquivalents = [
          \OCP\TextProcessing\FreePromptTaskType::class,
          \OCP\TextProcessing\HeadlineTaskType::class,
          \OCP\TextProcessing\TopicsTaskType::class,
          \OCP\TextProcessing\SummaryTaskType::class,
      ];
      $newTaskTypes = [];
      foreach ($this->textProcessingManager->getProviders() as $oldProvider) {
          // These are already implemented in the TaskProcessing realm
          if (in_array($oldProvider->getTaskType(), $nativeEquivalents, true)) {
              continue;
          }
          $taskType = new class($oldProvider->getTaskType()) implements ITaskType {
              private string $oldTaskTypeClass;
              private \OCP\TextProcessing\ITaskType $oldTaskType;
              private IL10N $l;

              public function __construct(string $oldTaskTypeClass) {
                  $this->oldTaskTypeClass = $oldTaskTypeClass;
                  $this->oldTaskType = \OCP\Server::get($oldTaskTypeClass);
                  $this->l = \OCP\Server::get(IFactory::class)->get('core');
              }

              public function getId(): string {
                  return Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->oldTaskTypeClass;
              }

              public function getName(): string {
                  return $this->oldTaskType->getName();
              }

              public function getDescription(): string {
                  return $this->oldTaskType->getDescription();
              }

              public function getInputShape(): array {
                  return ['input' => new ShapeDescriptor($this->l->t('Input text'), $this->l->t('The input text'), EShapeType::Text)];
              }

              public function getOutputShape(): array {
                  // The output slot must be described as output; it was previously
                  // labelled with the input strings.
                  return ['output' => new ShapeDescriptor($this->l->t('Output text'), $this->l->t('The output text'), EShapeType::Text)];
              }
          };
          $newTaskTypes[$taskType->getId()] = $taskType;
      }
      return $newTaskTypes;
  }
  /**
   * Wraps every provider of the legacy TextToImage manager in an adapter
   * implementing the TaskProcessing provider interfaces.
   *
   * @return IProvider[] the adapters, keyed by their prefixed legacy ID
   */
  private function _getTextToImageProviders(): array {
      $oldProviders = $this->textToImageManager->getProviders();
      $newProviders = [];
      foreach ($oldProviders as $oldProvider) {
          $newProvider = new class($oldProvider, $this->appData) implements IProvider, ISynchronousProvider {
              private \OCP\TextToImage\IProvider $provider;
              private IAppData $appData;

              public function __construct(\OCP\TextToImage\IProvider $provider, IAppData $appData) {
                  $this->provider = $provider;
                  $this->appData = $appData;
              }

              public function getId(): string {
                  return Manager::LEGACY_PREFIX_TEXTTOIMAGE . $this->provider->getId();
              }

              public function getName(): string {
                  return $this->provider->getName();
              }

              public function getTaskTypeId(): string {
                  return TextToImage::ID;
              }

              public function getExpectedRuntime(): int {
                  return $this->provider->getExpectedRuntime();
              }

              public function getOptionalInputShape(): array {
                  return [];
              }

              public function getOptionalOutputShape(): array {
                  return [];
              }

              /**
               * Creates one file per requested image in the 'text2image' app data
               * folder, lets the legacy provider write into the opened handles and
               * returns the contents of those files.
               *
               * @throws ProcessingException if a file cannot be opened for writing or
               *                             the legacy provider throws a \RuntimeException
               */
              public function process(?string $userId, array $input, callable $reportProgress): array {
                  try {
                      $folder = $this->appData->getFolder('text2image');
                  } catch (\OCP\Files\NotFoundException) {
                      $folder = $this->appData->newFolder('text2image');
                  }
                  $resources = [];
                  $files = [];
                  try {
                      for ($i = 0; $i < $input['numberOfImages']; $i++) {
                          $file = $folder->newFile(time() . '-' . rand(1, 100000) . '-' . $i);
                          $files[] = $file;
                          $resource = $file->write();
                          if (!is_resource($resource)) {
                              throw new ProcessingException('Text2Image generation using provider "' . $this->getName() . '" failed: Couldn\'t open file to write.');
                          }
                          $resources[] = $resource;
                      }
                      if ($this->provider instanceof \OCP\TextToImage\IProviderWithUserId) {
                          $this->provider->setUserId($userId);
                      }
                      try {
                          $this->provider->generate($input['input'], $resources);
                      } catch (\RuntimeException $e) {
                          throw new ProcessingException($e->getMessage(), 0, $e);
                      }
                  } finally {
                      // Close every handle that is still open, including when opening
                      // a later file or the generation itself failed.
                      foreach ($resources as $resource) {
                          if (is_resource($resource)) {
                              fclose($resource);
                          }
                      }
                  }
                  return ['images' => array_map(fn (ISimpleFile $file) => $file->getContent(), $files)];
              }
          };
          $newProviders[$newProvider->getId()] = $newProvider;
      }
      return $newProviders;
  }
  /**
   * Wraps every provider of the legacy SpeechToText manager in an adapter
   * implementing the TaskProcessing provider interfaces.
   *
   * @return IProvider[] the adapters, keyed by their prefixed legacy ID
   */
  private function _getSpeechToTextProviders(): array {
      $adapters = [];
      foreach ($this->speechToTextManager->getProviders() as $legacyProvider) {
          $adapter = new class($legacyProvider, $this->rootFolder, $this->appData) implements IProvider, ISynchronousProvider {
              public function __construct(
                  private ISpeechToTextProvider $provider,
                  private IRootFolder $rootFolder,
                  private IAppData $appData,
              ) {
              }

              public function getId(): string {
                  // Use the provider's own ID when it has one, its class name otherwise.
                  $suffix = $this->provider instanceof ISpeechToTextProviderWithId
                      ? $this->provider->getId()
                      : $this->provider::class;
                  return Manager::LEGACY_PREFIX_SPEECHTOTEXT . $suffix;
              }

              public function getName(): string {
                  return $this->provider->getName();
              }

              public function getTaskTypeId(): string {
                  return AudioToText::ID;
              }

              public function getExpectedRuntime(): int {
                  return 60;
              }

              public function getOptionalInputShape(): array {
                  return [];
              }

              public function getOptionalOutputShape(): array {
                  return [];
              }

              public function process(?string $userId, array $input, callable $reportProgress): array {
                  try {
                      return ['output' => $this->provider->transcribeFile($input['input'])];
                  } catch (\RuntimeException $e) {
                      // Translate legacy failures into the TaskProcessing exception type.
                      throw new ProcessingException($e->getMessage(), 0, $e);
                  }
              }
          };
          $adapters[$adapter->getId()] = $adapter;
      }
      return $adapters;
  }
  /**
   * Instantiates all registered TaskProcessing providers and appends the
   * adapters for legacy TextProcessing, TextToImage and SpeechToText providers.
   *
   * A provider that fails to load is logged and skipped. When two registered
   * providers report the same ID a warning is logged and the later one wins;
   * legacy adapters never replace an already present ID.
   *
   * @return IProvider[] providers keyed by ID; empty if there is no registration context
   */
  private function _getProviders(): array {
      $registrationContext = $this->coordinator->getRegistrationContext();
      if ($registrationContext === null) {
          return [];
      }

      $loaded = [];
      foreach ($registrationContext->getTaskProcessingProviders() as $registration) {
          $serviceClass = $registration->getService();
          try {
              /** @var IProvider $instance */
              $instance = $this->serverContainer->get($serviceClass);
              $providerId = $instance->getId();
              if (isset($loaded[$providerId])) {
                  $this->logger->warning('Task processing provider ' . $serviceClass . ' is using ID ' . $providerId . ' which is already used by ' . $loaded[$providerId]::class);
              }
              $loaded[$providerId] = $instance;
          } catch (\Throwable $e) {
              $this->logger->error('Failed to load task processing provider ' . $serviceClass, [
                  'exception' => $e,
              ]);
          }
      }

      // Array union keeps the first occurrence of a key, so natively registered
      // providers take precedence over legacy adapters.
      $legacy = $this->_getTextProcessingProviders() + $this->_getTextToImageProviders() + $this->_getSpeechToTextProviders();
      return $loaded + $legacy;
  }
  /**
   * Collects the default task types, all registered task types and the
   * wrapped legacy TextProcessing task types.
   *
   * A task type that fails to load is logged and skipped. A registered task
   * type replaces an entry with the same ID (after a warning); wrapped legacy
   * task types never replace an already present ID.
   *
   * @return ITaskType[] task types keyed by ID; empty if there is no registration context
   */
  private function _getTaskTypes(): array {
      $registrationContext = $this->coordinator->getRegistrationContext();
      if ($registrationContext === null) {
          return [];
      }

      // Default task types
      $known = [
          TextToText::ID => \OCP\Server::get(TextToText::class),
          TextToTextTopics::ID => \OCP\Server::get(TextToTextTopics::class),
          TextToTextHeadline::ID => \OCP\Server::get(TextToTextHeadline::class),
          TextToTextSummary::ID => \OCP\Server::get(TextToTextSummary::class),
          TextToImage::ID => \OCP\Server::get(TextToImage::class),
          AudioToText::ID => \OCP\Server::get(AudioToText::class),
      ];

      foreach ($registrationContext->getTaskProcessingTaskTypes() as $registration) {
          $serviceClass = $registration->getService();
          try {
              /** @var ITaskType $instance */
              $instance = $this->serverContainer->get($serviceClass);
              $taskTypeId = $instance->getId();
              if (isset($known[$taskTypeId])) {
                  $this->logger->warning('Task processing task type ' . $serviceClass . ' is using ID ' . $taskTypeId . ' which is already used by ' . $known[$taskTypeId]::class);
              }
              $known[$taskTypeId] = $instance;
          } catch (\Throwable $e) {
              $this->logger->error('Failed to load task processing task type ' . $serviceClass, [
                  'exception' => $e,
              ]);
          }
      }

      return $known + $this->_getTextProcessingTaskTypes();
  }
  /**
   * Returns the first known provider that handles the given task type.
   *
   * @param string $taskType ID of the task type to find a provider for
   * @return IProvider
   * @throws \OCP\TaskProcessing\Exception\Exception if no provider handles the task type
   */
  private function _getPreferredProvider(string $taskType) {
      foreach ($this->getProviders() as $candidate) {
          if ($candidate->getTaskTypeId() !== $taskType) {
              continue;
          }
          return $candidate;
      }
      throw new \OCP\TaskProcessing\Exception\Exception('No matching provider found');
  }
  /**
   * Checks task input against a shape specification.
   *
   * @param ShapeDescriptor[] $spec descriptors keyed by input slot name
   * @param array $io the input data keyed by slot name
   * @param bool $optional when true, slots absent from $io are skipped instead of rejected
   * @return void
   * @throws ValidationException if a required slot is missing or a value fails its shape type's validation
   */
  private function validateInput(array $spec, array $io, bool $optional = false): void {
      foreach ($spec as $key => $descriptor) {
          $shapeType = $descriptor->getShapeType();
          if (isset($io[$key])) {
              try {
                  $shapeType->validateInput($io[$key]);
              } catch (ValidationException $e) {
                  throw new ValidationException('Failed to validate input key "' . $key . '": ' . $e->getMessage());
              }
              continue;
          }
          if (!$optional) {
              throw new ValidationException('Missing key: "' . $key . '"');
          }
      }
  }
  /**
   * Checks task output against a shape specification.
   *
   * @param ShapeDescriptor[] $spec descriptors keyed by output slot name
   * @param array $io the output data keyed by slot name
   * @param bool $optional when true, slots absent from $io are skipped instead of rejected
   * @return void
   * @throws ValidationException if a required slot is missing or a value fails its shape type's validation
   */
  private function validateOutput(array $spec, array $io, bool $optional = false): void {
      foreach ($spec as $key => $descriptor) {
          $shapeType = $descriptor->getShapeType();
          if (isset($io[$key])) {
              try {
                  $shapeType->validateOutput($io[$key]);
              } catch (ValidationException $e) {
                  throw new ValidationException('Failed to validate output key "' . $key . '": ' . $e->getMessage());
              }
              continue;
          }
          if (!$optional) {
              throw new ValidationException('Missing key: "' . $key . '"');
          }
      }
  }
  /**
   * Reduces an array to the entries whose keys appear in at least one spec.
   *
   * Keys are taken from every spec in turn. Spec keys that are absent from
   * $array (e.g. optional slots that were not supplied) are left out of the
   * result instead of being added with a null value.
   *
   * @param array<array-key, T> $array The array to filter
   * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep
   * @return array<array-key, T>
   * @psalm-template T
   */
  private function removeSuperfluousArrayKeys(array $array, ...$specs): array {
      $kept = [];
      foreach ($specs as $spec) {
          // Walk each spec's keys explicitly: combining the key lists with the
          // array union operator would drop keys of later specs whose list
          // index collides with an earlier one.
          foreach (array_keys($spec) as $key) {
              if (array_key_exists($key, $array)) {
                  $kept[$key] = $array[$key];
              }
          }
      }
      return $kept;
  }
  /**
   * Tells whether at least one provider is known.
   */
  public function hasProviders(): bool {
      $providerCount = count($this->getProviders());
      return $providerCount !== 0;
  }
  /**
   * Returns all providers, loading them on first use and reusing the loaded
   * list on later calls.
   */
  public function getProviders(): array {
      $this->providers ??= $this->_getProviders();
      return $this->providers;
  }
  /**
   * Describes every task type that has at least one provider.
   *
   * The required shapes come from the task type, the optional shapes from the
   * provider. The result is computed once and reused on later calls.
   */
  public function getAvailableTaskTypes(): array {
      if ($this->availableTaskTypes !== null) {
          return $this->availableTaskTypes;
      }

      $knownTaskTypes = $this->_getTaskTypes();
      $available = [];
      foreach ($this->getProviders() as $provider) {
          $taskTypeId = $provider->getTaskTypeId();
          $taskType = $knownTaskTypes[$taskTypeId] ?? null;
          if ($taskType === null) {
              // Provider refers to a task type that is not known: ignore it.
              continue;
          }
          $available[$taskTypeId] = [
              'name' => $taskType->getName(),
              'description' => $taskType->getDescription(),
              'inputShape' => $taskType->getInputShape(),
              'optionalInputShape' => $provider->getOptionalInputShape(),
              'outputShape' => $taskType->getOutputShape(),
              'optionalOutputShape' => $provider->getOptionalOutputShape(),
          ];
      }

      $this->availableTaskTypes = $available;
      return $this->availableTaskTypes;
  }
  /**
   * Tells whether the task's type is among the available task types.
   */
  public function canHandleTask(Task $task): bool {
      $availableTaskTypes = $this->getAvailableTaskTypes();
      return isset($availableTaskTypes[$task->getTaskTypeId()]);
  }
  /**
   * Validates a task, checks access to the files it references, stores it and
   * queues the synchronous background job if the preferred provider is
   * synchronous.
   *
   * @throws \OCP\TaskProcessing\Exception\PreConditionNotMetException if no provider can handle the task type
   * @throws ValidationException if the input does not match the task type's shapes or a referenced file cannot be found
   * @throws UnauthorizedException if the task's user is not in the access list of a referenced file
   * @throws \OCP\TaskProcessing\Exception\Exception if no matching provider is found
   */
  public function scheduleTask(Task $task): void {
      if (!$this->canHandleTask($task)) {
          throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskTypeId());
      }
      $taskTypes = $this->getAvailableTaskTypes();
      $inputShape = $taskTypes[$task->getTaskTypeId()]['inputShape'];
      $optionalInputShape = $taskTypes[$task->getTaskTypeId()]['optionalInputShape'];
      // validate input
      $this->validateInput($inputShape, $task->getInput());
      $this->validateInput($optionalInputShape, $task->getInput(), true);
      // authenticate access to mentioned files
      $taskInput = $task->getInput();
      $ids = [];
      foreach ($inputShape + $optionalInputShape as $key => $descriptor) {
          if (!in_array(EShapeType::getScalarType($descriptor->getShapeType()), [EShapeType::File, EShapeType::Image, EShapeType::Audio, EShapeType::Video], true)) {
              continue;
          }
          if (!isset($taskInput[$key])) {
              // Optional file slot that was not supplied: nothing to check.
              continue;
          }
          /** @var list<int>|int $inputSlot */
          $inputSlot = $taskInput[$key];
          if (is_array($inputSlot)) {
              // array_merge appends every ID; the union operator would drop
              // IDs whose list index is already taken, skipping their check.
              $ids = array_merge($ids, array_values($inputSlot));
          } else {
              $ids[] = $inputSlot;
          }
      }
      foreach ($ids as $fileId) {
          $node = $this->rootFolder->getFirstNodeById($fileId);
          if ($node === null) {
              $node = $this->rootFolder->getFirstNodeByIdInPath($fileId, '/' . $this->rootFolder->getAppDataDirectoryName() . '/');
              if ($node === null) {
                  throw new ValidationException('Could not find file ' . $fileId);
              }
          }
          /** @var array{users:array<string,array{node_id:int, node_path: string}>, remote: array<string,array{node_id:int, node_path: string}>, mail: array<string,array{node_id:int, node_path: string}>} $accessList */
          $accessList = $this->shareManager->getAccessList($node, true, true);
          $userIds = array_map(fn ($id) => strval($id), array_keys($accessList['users']));
          // Strict comparison: a missing (null) user ID must never match.
          if (!in_array($task->getUserId(), $userIds, true)) {
              throw new UnauthorizedException('User ' . $task->getUserId() . ' does not have access to file ' . $fileId);
          }
      }
      // remove superfluous keys and set input
      $task->setInput($this->removeSuperfluousArrayKeys($task->getInput(), $inputShape, $optionalInputShape));
      $task->setStatus(Task::STATUS_SCHEDULED);
      $provider = $this->_getPreferredProvider($task->getTaskTypeId());
      // calculate expected completion time
      $completionExpectedAt = new \DateTime('now');
      $completionExpectedAt->add(new \DateInterval('PT' . $provider->getExpectedRuntime() . 'S'));
      $task->setCompletionExpectedAt($completionExpectedAt);
      // create a db entity and insert into db table
      $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task);
      $this->taskMapper->insert($taskEntity);
      // make sure the scheduler knows the id
      $task->setId($taskEntity->getId());
      // schedule synchronous job if the provider is synchronous
      if ($provider instanceof ISynchronousProvider) {
          $this->jobList->add(SynchronousBackgroundJob::class, null);
      }
  }
  /**
   * Deletes the stored entity that corresponds to the given task.
   */
  public function deleteTask(Task $task): void {
      $this->taskMapper->delete(\OC\TaskProcessing\Db\Task::fromPublicTask($task));
  }
  /**
   * Loads a task by ID.
   *
   * @throws NotFoundException if no task with this ID exists
   * @throws \OCP\TaskProcessing\Exception\Exception on database or JSON errors
   */
  public function getTask(int $id): Task {
      try {
          return $this->taskMapper->find($id)->toPublicTask();
      } catch (DoesNotExistException $e) {
          throw new NotFoundException('Couldn\'t find task with id ' . $id, 0, $e);
      } catch (MultipleObjectsReturnedException|\OCP\DB\Exception $e) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
      } catch (\JsonException $e) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $e);
      }
  }
  /**
   * Marks a scheduled or running task as cancelled; tasks in any other state
   * are left untouched.
   *
   * @throws \OCP\TaskProcessing\Exception\Exception if the task cannot be loaded or stored
   */
  public function cancelTask(int $id): void {
      $task = $this->getTask($id);
      $isCancellable = $task->getStatus() === Task::STATUS_SCHEDULED
          || $task->getStatus() === Task::STATUS_RUNNING;
      if (!$isCancellable) {
          return;
      }

      $task->setStatus(Task::STATUS_CANCELLED);
      try {
          $this->taskMapper->update(\OC\TaskProcessing\Db\Task::fromPublicTask($task));
      } catch (\OCP\DB\Exception $e) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
      }
  }
  /**
   * Stores the progress of a task and marks it as running.
   *
   * @return bool false if the task has been cancelled (nothing is stored), true otherwise
   * @throws \OCP\TaskProcessing\Exception\Exception if the task cannot be loaded or stored
   */
  public function setTaskProgress(int $id, float $progress): bool {
      // TODO: Not sure if we should rather catch the exceptions of getTask here and fail silently
      $task = $this->getTask($id);
      if ($task->getStatus() === Task::STATUS_CANCELLED) {
          return false;
      }

      $task->setStatus(Task::STATUS_RUNNING);
      $task->setProgress($progress);
      $entity = \OC\TaskProcessing\Db\Task::fromPublicTask($task);
      try {
          $this->taskMapper->update($entity);
          return true;
      } catch (\OCP\DB\Exception $e) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
      }
  }
  /**
   * Stores the outcome of a processed task and dispatches the matching event.
   *
   * A cancelled task is left as is. With an error the task is marked failed.
   * With a result the output is validated against the task type's shapes,
   * stripped of unknown keys and passed through encapsulateOutputFileData()
   * before the task is marked successful; if that fails the task is marked
   * failed instead.
   *
   * @throws \OCP\TaskProcessing\Exception\Exception if the task cannot be loaded or stored
   */
  public function setTaskResult(int $id, ?string $error, ?array $result): void {
      // TODO: Not sure if we should rather catch the exceptions of getTask here and fail silently
      $task = $this->getTask($id);
      if ($task->getStatus() === Task::STATUS_CANCELLED) {
          $this->logger->info('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' finished but was cancelled in the mean time. Moving on without storing result.');
          return;
      }

      if ($error !== null) {
          $task->setStatus(Task::STATUS_FAILED);
          $task->setErrorMessage($error);
          $this->logger->warning('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' failed with the following message: ' . $error);
      } elseif ($result !== null) {
          $shapes = $this->getAvailableTaskTypes()[$task->getTaskTypeId()];
          $outputShape = $shapes['outputShape'];
          $optionalOutputShape = $shapes['optionalOutputShape'];
          try {
              // validate output
              $this->validateOutput($outputShape, $result);
              $this->validateOutput($optionalOutputShape, $result, true);
              $output = $this->removeSuperfluousArrayKeys($result, $outputShape, $optionalOutputShape);
              // extract raw data and put it in files, replace it with file ids
              $output = $this->encapsulateOutputFileData($output, $outputShape, $optionalOutputShape);
              $task->setOutput($output);
              $task->setProgress(1);
              $task->setStatus(Task::STATUS_SUCCESSFUL);
          } catch (ValidationException|NotPermittedException $e) {
              // Both failure kinds end the task as failed; only the message differs.
              $error = $e instanceof ValidationException
                  ? 'The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec'
                  : 'The task was processed successfully but storing the output in a file failed';
              $task->setProgress(1);
              $task->setStatus(Task::STATUS_FAILED);
              $task->setErrorMessage($error);
              $this->logger->error($error, ['exception' => $e]);
          }
      }

      try {
          $this->taskMapper->update(\OC\TaskProcessing\Db\Task::fromPublicTask($task));
      } catch (\OCP\DB\Exception $e) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
      }

      $event = $task->getStatus() === Task::STATUS_SUCCESSFUL
          ? new TaskSuccessfulEvent($task)
          : new TaskFailedEvent($task, $error);
      $this->dispatcher->dispatchTyped($event);
  }
  /**
   * Returns the oldest scheduled task, as found by the task mapper for the
   * given task type ID (which may be null).
   *
   * @throws \OCP\TaskProcessing\Exception\NotFoundException if there is no such task
   * @throws \OCP\TaskProcessing\Exception\Exception on database or JSON errors
   */
  public function getNextScheduledTask(?string $taskTypeId = null): Task {
      try {
          return $this->taskMapper->findOldestScheduledByType($taskTypeId)->toPublicTask();
      } catch (DoesNotExistException $e) {
          throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', 0, $e);
      } catch (\OCP\DB\Exception $e) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e);
      } catch (\JsonException $e) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $e);
      }
  }
  /**
   * Takes task input or output data and replaces file IDs with the
   * corresponding File nodes.
   *
   * Only keys present in one of the specs are copied to the result; keys
   * missing from $input are skipped. Values of non-file shape types are copied
   * unchanged. Files are looked up by ID first and, if not found, inside the
   * app data directory.
   *
   * @param string|null $userId if not null, the filesystem is set up for this user first
   * @param array<array-key, list<numeric|string>|numeric|string> $input
   * @param ShapeDescriptor[] ...$specs the specs
   * @return array<array-key, list<File|numeric|string>|numeric|string|File>
   * @throws GenericFileException
   * @throws LockedException
   * @throws NotPermittedException
   * @throws ValidationException if an ID does not resolve to a file
   */
  public function fillInputFileData(?string $userId, array $input, ...$specs): array {
      if ($userId !== null) {
          \OC_Util::setupFS($userId);
      }

      // Resolves one file ID to a File node or rejects it.
      $resolveFile = function (int $fileId, $key): File {
          $node = $this->rootFolder->getFirstNodeById($fileId)
              ?? $this->rootFolder->getFirstNodeByIdInPath($fileId, '/' . $this->rootFolder->getAppDataDirectoryName() . '/');
          if (!$node instanceof File) {
              throw new ValidationException('File id given for key "' . $key . '" is not a file');
          }
          // TODO: Validate if userId has access to this file
          return $node;
      };

      $fileShapes = [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File];
      $mergedSpec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []);
      $filled = [];
      foreach ($mergedSpec as $key => $descriptor) {
          $type = $descriptor->getShapeType();
          if (!isset($input[$key])) {
              continue;
          }
          if (!in_array(EShapeType::getScalarType($type), $fileShapes, true)) {
              $filled[$key] = $input[$key];
              continue;
          }
          // NOTE(review): values below 10 are treated as a single file and all
          // others as a list of files; confirm against the EShapeType definition.
          if ($type->value < 10) {
              $filled[$key] = $resolveFile((int)$input[$key], $key);
              continue;
          }
          $filled[$key] = [];
          foreach ($input[$key] as $item) {
              $filled[$key][] = $resolveFile((int)$item, $key);
          }
      }
      return $filled;
  }
  /**
   * Looks up a single task via the mapper's findByIdAndUser() and returns
   * its public representation.
   *
   * @param int $id Task id, handed to the mapper unchanged
   * @param string|null $userId User id, handed to the mapper unchanged
   * @return Task
   * @throws \OCP\TaskProcessing\Exception\NotFoundException If the mapper finds no matching task
   * @throws \OCP\TaskProcessing\Exception\Exception If several rows match, on a database error, or when the stored task cannot be parsed from JSON
   */
  public function getUserTask(int $id, ?string $userId): Task {
      try {
          // The conversion stays inside the try block: it may raise a JsonException
          return $this->taskMapper->findByIdAndUser($id, $userId)->toPublicTask();
      } catch (DoesNotExistException $notFound) {
          throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', 0, $notFound);
      } catch (MultipleObjectsReturnedException|\OCP\DB\Exception $lookupError) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $lookupError);
      } catch (\JsonException $jsonError) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $jsonError);
      }
  }
  /**
   * Fetches task entities via the mapper's findByUserAndTaskType() and
   * converts each of them to its public representation.
   *
   * @param string|null $userId User id, handed to the mapper unchanged
   * @param string|null $taskTypeId Task type id, handed to the mapper unchanged
   * @param string|null $customId Custom id, handed to the mapper unchanged
   * @return array The converted tasks, keyed like the mapper result
   * @throws \OCP\TaskProcessing\Exception\Exception On a database error, or when a stored task cannot be parsed from JSON
   */
  public function getUserTasks(?string $userId, ?string $taskTypeId = null, ?string $customId = null): array {
      try {
          $entities = $this->taskMapper->findByUserAndTaskType($userId, $taskTypeId, $customId);
          $toPublicTask = static function ($entity): Task {
              return $entity->toPublicTask();
          };
          return array_map($toPublicTask, $entities);
      } catch (\OCP\DB\Exception $dbError) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the tasks', 0, $dbError);
      } catch (\JsonException $jsonError) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the tasks', 0, $jsonError);
      }
  }
  /**
   * Fetches task entities via the mapper's findUserTasksByApp() and
   * converts each of them to its public representation.
   *
   * @param string|null $userId User id, handed to the mapper unchanged
   * @param string $appId App id, handed to the mapper unchanged
   * @param string|null $customId Custom id, handed to the mapper unchanged
   * @return array The converted tasks, keyed like the mapper result
   * @throws \OCP\TaskProcessing\Exception\Exception On a database error, or when a stored task cannot be parsed from JSON
   */
  public function getUserTasksByApp(?string $userId, string $appId, ?string $customId = null): array {
      try {
          return array_map(
              static fn ($entity): Task => $entity->toPublicTask(),
              $this->taskMapper->findUserTasksByApp($userId, $appId, $customId),
          );
      } catch (\OCP\DB\Exception $dbError) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding a task', 0, $dbError);
      } catch (\JsonException $jsonError) {
          throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding a task', 0, $jsonError);
      }
  }
  /**
   * Takes task input or output and replaces file data with file ids
   *
   * Every value of a file-like shape type is written to a new file in the
   * "TaskProcessing" app data folder (created if missing) and replaced by the
   * id of that file. Values of other shape types are copied unchanged. Keys
   * that are not part of the specs, or not set in $output, are dropped.
   *
   * @param array $output
   * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep
   * @return array
   * @throws NotPermittedException
   */
  public function encapsulateOutputFileData(array $output, ...$specs): array {
      $newOutput = [];
      try {
          $folder = $this->appData->getFolder('TaskProcessing');
      } catch (\OCP\Files\NotFoundException) {
          $folder = $this->appData->newFolder('TaskProcessing');
      }
      $fileShapeTypes = [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File];
      // Merge all specs into one map; with "+" the first spec defining a key wins
      $spec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []);
      foreach ($spec as $key => $descriptor) {
          $type = $descriptor->getShapeType();
          if (!isset($output[$key])) {
              continue;
          }
          if (!in_array(EShapeType::getScalarType($type), $fileShapeTypes, true)) {
              // Not a file-like shape: pass the value through untouched
              $newOutput[$key] = $output[$key];
              continue;
          }
          // NOTE(review): file names come from a small random range, so two
          // outputs may get the same name; confirm how newFile() handles that.
          if ($type->value < 10) {
              // Shape types with a value below 10 are treated as a single file
              /** @var SimpleFile $file */
              $file = $folder->newFile((string) rand(0, 10000000), $output[$key]);
              $newOutput[$key] = $file->getId(); // polymorphic call to SimpleFile
          } else {
              // Initialise only this key's list. Resetting the whole result
              // array here would throw away every key processed before it.
              $newOutput[$key] = [];
              foreach ($output[$key] as $item) {
                  /** @var SimpleFile $file */
                  $file = $folder->newFile((string) rand(0, 10000000), $item);
                  $newOutput[$key][] = $file->getId();
              }
          }
      }
      return $newOutput;
  }
  /**
   * Builds the input for a task: validates it against the required and
   * optional input shapes of the task's type, runs it through
   * removeSuperfluousArrayKeys() and fillInputFileData(), and returns the result.
   *
   * @param Task $task
   * @return array<array-key, list<numeric|string|File>|numeric|string|File>
   * @throws GenericFileException
   * @throws LockedException
   * @throws NotPermittedException
   * @throws ValidationException
   */
  public function prepareInputData(Task $task): array {
      $availableTypes = $this->getAvailableTaskTypes();
      $typeId = $task->getTaskTypeId();
      $requiredShape = $availableTypes[$typeId]['inputShape'];
      $optionalShape = $availableTypes[$typeId]['optionalInputShape'];
      $rawInput = $task->getInput();

      // validate input, again for good measure (should have been validated in scheduleTask)
      $this->validateInput($requiredShape, $rawInput);
      $this->validateInput($optionalShape, $rawInput, true);

      $trimmedInput = $this->removeSuperfluousArrayKeys($rawInput, $requiredShape, $optionalShape);

      return $this->fillInputFileData($task->getUserId(), $trimmedInput, $requiredShape, $optionalShape);
  }
  794. }