TaskProcessingTest.php 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914
  1. <?php
  2. /**
  3. * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
  4. * SPDX-License-Identifier: AGPL-3.0-or-later
  5. */
  6. namespace Test\TaskProcessing;
  7. use OC\AppFramework\Bootstrap\Coordinator;
  8. use OC\AppFramework\Bootstrap\RegistrationContext;
  9. use OC\AppFramework\Bootstrap\ServiceRegistration;
  10. use OC\EventDispatcher\EventDispatcher;
  11. use OC\TaskProcessing\Db\TaskMapper;
  12. use OC\TaskProcessing\Manager;
  13. use OC\TaskProcessing\RemoveOldTasksBackgroundJob;
  14. use OCP\App\IAppManager;
  15. use OCP\AppFramework\Utility\ITimeFactory;
  16. use OCP\BackgroundJob\IJobList;
  17. use OCP\EventDispatcher\IEventDispatcher;
  18. use OCP\Files\AppData\IAppDataFactory;
  19. use OCP\Files\Config\ICachedMountInfo;
  20. use OCP\Files\Config\IUserMountCache;
  21. use OCP\Files\IRootFolder;
  22. use OCP\Http\Client\IClientService;
  23. use OCP\IConfig;
  24. use OCP\IDBConnection;
  25. use OCP\IServerContainer;
  26. use OCP\IUser;
  27. use OCP\IUserManager;
  28. use OCP\TaskProcessing\EShapeType;
  29. use OCP\TaskProcessing\Events\TaskFailedEvent;
  30. use OCP\TaskProcessing\Events\TaskSuccessfulEvent;
  31. use OCP\TaskProcessing\Exception\NotFoundException;
  32. use OCP\TaskProcessing\Exception\ProcessingException;
  33. use OCP\TaskProcessing\Exception\UnauthorizedException;
  34. use OCP\TaskProcessing\Exception\ValidationException;
  35. use OCP\TaskProcessing\IManager;
  36. use OCP\TaskProcessing\IProvider;
  37. use OCP\TaskProcessing\ISynchronousProvider;
  38. use OCP\TaskProcessing\ITaskType;
  39. use OCP\TaskProcessing\ShapeDescriptor;
  40. use OCP\TaskProcessing\Task;
  41. use OCP\TaskProcessing\TaskTypes\TextToImage;
  42. use OCP\TaskProcessing\TaskTypes\TextToText;
  43. use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
  44. use OCP\TextProcessing\SummaryTaskType;
  45. use PHPUnit\Framework\Constraint\IsInstanceOf;
  46. use Psr\Log\LoggerInterface;
  47. use Test\BackgroundJob\DummyJobList;
  48. class AudioToImage implements ITaskType {
  49. public const ID = 'test:audiotoimage';
  50. public function getId(): string {
  51. return self::ID;
  52. }
  53. public function getName(): string {
  54. return self::class;
  55. }
  56. public function getDescription(): string {
  57. return self::class;
  58. }
  59. public function getInputShape(): array {
  60. return [
  61. 'audio' => new ShapeDescriptor('Audio', 'The audio', EShapeType::Audio),
  62. ];
  63. }
  64. public function getOutputShape(): array {
  65. return [
  66. 'spectrogram' => new ShapeDescriptor('Spectrogram', 'The audio spectrogram', EShapeType::Image),
  67. ];
  68. }
  69. }
  70. class AsyncProvider implements IProvider {
  71. public function getId(): string {
  72. return 'test:sync:success';
  73. }
  74. public function getName(): string {
  75. return self::class;
  76. }
  77. public function getTaskTypeId(): string {
  78. return AudioToImage::ID;
  79. }
  80. public function getExpectedRuntime(): int {
  81. return 10;
  82. }
  83. public function getOptionalInputShape(): array {
  84. return [
  85. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  86. ];
  87. }
  88. public function getOptionalOutputShape(): array {
  89. return [
  90. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  91. ];
  92. }
  93. public function getInputShapeEnumValues(): array {
  94. return [];
  95. }
  96. public function getInputShapeDefaults(): array {
  97. return [];
  98. }
  99. public function getOptionalInputShapeEnumValues(): array {
  100. return [];
  101. }
  102. public function getOptionalInputShapeDefaults(): array {
  103. return [];
  104. }
  105. public function getOutputShapeEnumValues(): array {
  106. return [];
  107. }
  108. public function getOptionalOutputShapeEnumValues(): array {
  109. return [];
  110. }
  111. }
  112. class SuccessfulSyncProvider implements IProvider, ISynchronousProvider {
  113. public function getId(): string {
  114. return 'test:sync:success';
  115. }
  116. public function getName(): string {
  117. return self::class;
  118. }
  119. public function getTaskTypeId(): string {
  120. return TextToText::ID;
  121. }
  122. public function getExpectedRuntime(): int {
  123. return 10;
  124. }
  125. public function getOptionalInputShape(): array {
  126. return [
  127. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  128. ];
  129. }
  130. public function getOptionalOutputShape(): array {
  131. return [
  132. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  133. ];
  134. }
  135. public function process(?string $userId, array $input, callable $reportProgress): array {
  136. return ['output' => $input['input']];
  137. }
  138. public function getInputShapeEnumValues(): array {
  139. return [];
  140. }
  141. public function getInputShapeDefaults(): array {
  142. return [];
  143. }
  144. public function getOptionalInputShapeEnumValues(): array {
  145. return [];
  146. }
  147. public function getOptionalInputShapeDefaults(): array {
  148. return [];
  149. }
  150. public function getOutputShapeEnumValues(): array {
  151. return [];
  152. }
  153. public function getOptionalOutputShapeEnumValues(): array {
  154. return [];
  155. }
  156. }
  157. class FailingSyncProvider implements IProvider, ISynchronousProvider {
  158. public const ERROR_MESSAGE = 'Failure';
  159. public function getId(): string {
  160. return 'test:sync:fail';
  161. }
  162. public function getName(): string {
  163. return self::class;
  164. }
  165. public function getTaskTypeId(): string {
  166. return TextToText::ID;
  167. }
  168. public function getExpectedRuntime(): int {
  169. return 10;
  170. }
  171. public function getOptionalInputShape(): array {
  172. return [
  173. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  174. ];
  175. }
  176. public function getOptionalOutputShape(): array {
  177. return [
  178. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  179. ];
  180. }
  181. public function process(?string $userId, array $input, callable $reportProgress): array {
  182. throw new ProcessingException(self::ERROR_MESSAGE);
  183. }
  184. public function getInputShapeEnumValues(): array {
  185. return [];
  186. }
  187. public function getInputShapeDefaults(): array {
  188. return [];
  189. }
  190. public function getOptionalInputShapeEnumValues(): array {
  191. return [];
  192. }
  193. public function getOptionalInputShapeDefaults(): array {
  194. return [];
  195. }
  196. public function getOutputShapeEnumValues(): array {
  197. return [];
  198. }
  199. public function getOptionalOutputShapeEnumValues(): array {
  200. return [];
  201. }
  202. }
  203. class BrokenSyncProvider implements IProvider, ISynchronousProvider {
  204. public function getId(): string {
  205. return 'test:sync:broken-output';
  206. }
  207. public function getName(): string {
  208. return self::class;
  209. }
  210. public function getTaskTypeId(): string {
  211. return TextToText::ID;
  212. }
  213. public function getExpectedRuntime(): int {
  214. return 10;
  215. }
  216. public function getOptionalInputShape(): array {
  217. return [
  218. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  219. ];
  220. }
  221. public function getOptionalOutputShape(): array {
  222. return [
  223. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  224. ];
  225. }
  226. public function process(?string $userId, array $input, callable $reportProgress): array {
  227. return [];
  228. }
  229. public function getInputShapeEnumValues(): array {
  230. return [];
  231. }
  232. public function getInputShapeDefaults(): array {
  233. return [];
  234. }
  235. public function getOptionalInputShapeEnumValues(): array {
  236. return [];
  237. }
  238. public function getOptionalInputShapeDefaults(): array {
  239. return [];
  240. }
  241. public function getOutputShapeEnumValues(): array {
  242. return [];
  243. }
  244. public function getOptionalOutputShapeEnumValues(): array {
  245. return [];
  246. }
  247. }
  248. class SuccessfulTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
  249. public bool $ran = false;
  250. public function getName(): string {
  251. return 'TEST Vanilla LLM Provider';
  252. }
  253. public function process(string $prompt): string {
  254. $this->ran = true;
  255. return $prompt . ' Summarize';
  256. }
  257. public function getTaskType(): string {
  258. return SummaryTaskType::class;
  259. }
  260. }
  261. class FailingTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
  262. public bool $ran = false;
  263. public function getName(): string {
  264. return 'TEST Vanilla LLM Provider';
  265. }
  266. public function process(string $prompt): string {
  267. $this->ran = true;
  268. throw new \Exception('ERROR');
  269. }
  270. public function getTaskType(): string {
  271. return SummaryTaskType::class;
  272. }
  273. }
  274. class SuccessfulTextToImageProvider implements \OCP\TextToImage\IProvider {
  275. public bool $ran = false;
  276. public function getId(): string {
  277. return 'test:successful';
  278. }
  279. public function getName(): string {
  280. return 'TEST Provider';
  281. }
  282. public function generate(string $prompt, array $resources): void {
  283. $this->ran = true;
  284. foreach ($resources as $resource) {
  285. fwrite($resource, 'test');
  286. }
  287. }
  288. public function getExpectedRuntime(): int {
  289. return 1;
  290. }
  291. }
  292. class FailingTextToImageProvider implements \OCP\TextToImage\IProvider {
  293. public bool $ran = false;
  294. public function getId(): string {
  295. return 'test:failing';
  296. }
  297. public function getName(): string {
  298. return 'TEST Provider';
  299. }
  300. public function generate(string $prompt, array $resources): void {
  301. $this->ran = true;
  302. throw new \RuntimeException('ERROR');
  303. }
  304. public function getExpectedRuntime(): int {
  305. return 1;
  306. }
  307. }
  308. /**
  309. * @group DB
  310. */
  311. class TaskProcessingTest extends \Test\TestCase {
  312. private IManager $manager;
  313. private Coordinator $coordinator;
  314. private array $providers;
  315. private IServerContainer $serverContainer;
  316. private IEventDispatcher $eventDispatcher;
  317. private RegistrationContext $registrationContext;
  318. private TaskMapper $taskMapper;
  319. private IJobList $jobList;
  320. private IUserMountCache $userMountCache;
  321. private IRootFolder $rootFolder;
  322. public const TEST_USER = 'testuser';
  323. protected function setUp(): void {
  324. parent::setUp();
  325. $this->providers = [
  326. SuccessfulSyncProvider::class => new SuccessfulSyncProvider(),
  327. FailingSyncProvider::class => new FailingSyncProvider(),
  328. BrokenSyncProvider::class => new BrokenSyncProvider(),
  329. AsyncProvider::class => new AsyncProvider(),
  330. AudioToImage::class => new AudioToImage(),
  331. SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(),
  332. FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(),
  333. SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(),
  334. FailingTextToImageProvider::class => new FailingTextToImageProvider(),
  335. ];
  336. $userManager = \OCP\Server::get(IUserManager::class);
  337. if (!$userManager->userExists(self::TEST_USER)) {
  338. $userManager->createUser(self::TEST_USER, 'test');
  339. }
  340. $this->serverContainer = $this->createMock(IServerContainer::class);
  341. $this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) {
  342. return $this->providers[$class];
  343. });
  344. $this->eventDispatcher = new EventDispatcher(
  345. new \Symfony\Component\EventDispatcher\EventDispatcher(),
  346. $this->serverContainer,
  347. \OC::$server->get(LoggerInterface::class),
  348. );
  349. $this->registrationContext = $this->createMock(RegistrationContext::class);
  350. $this->coordinator = $this->createMock(Coordinator::class);
  351. $this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext);
  352. $this->rootFolder = \OCP\Server::get(IRootFolder::class);
  353. $this->taskMapper = \OCP\Server::get(TaskMapper::class);
  354. $this->jobList = $this->createPartialMock(DummyJobList::class, ['add']);
  355. $this->jobList->expects($this->any())->method('add')->willReturnCallback(function () {
  356. });
  357. $config = $this->createMock(IConfig::class);
  358. $config->method('getAppValue')
  359. ->with('core', 'ai.textprocessing_provider_preferences', '')
  360. ->willReturn('');
  361. $this->eventDispatcher = $this->createMock(IEventDispatcher::class);
  362. $text2imageManager = new \OC\TextToImage\Manager(
  363. $this->serverContainer,
  364. $this->coordinator,
  365. \OC::$server->get(LoggerInterface::class),
  366. $this->jobList,
  367. \OC::$server->get(\OC\TextToImage\Db\TaskMapper::class),
  368. \OC::$server->get(IConfig::class),
  369. \OC::$server->get(IAppDataFactory::class),
  370. );
  371. $this->userMountCache = $this->createMock(IUserMountCache::class);
  372. $this->manager = new Manager(
  373. \OC::$server->get(IConfig::class),
  374. $this->coordinator,
  375. $this->serverContainer,
  376. \OC::$server->get(LoggerInterface::class),
  377. $this->taskMapper,
  378. $this->jobList,
  379. $this->eventDispatcher,
  380. \OC::$server->get(IAppDataFactory::class),
  381. \OC::$server->get(IRootFolder::class),
  382. $text2imageManager,
  383. $this->userMountCache,
  384. \OC::$server->get(IClientService::class),
  385. \OC::$server->get(IAppManager::class),
  386. );
  387. }
  388. private function getFile(string $name, string $content): \OCP\Files\File {
  389. $folder = $this->rootFolder->getUserFolder(self::TEST_USER);
  390. $file = $folder->newFile($name, $content);
  391. return $file;
  392. }
  393. public function testShouldNotHaveAnyProviders(): void {
  394. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]);
  395. self::assertCount(0, $this->manager->getAvailableTaskTypes());
  396. self::assertFalse($this->manager->hasProviders());
  397. self::expectException(\OCP\TaskProcessing\Exception\PreConditionNotMetException::class);
  398. $this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null));
  399. }
  400. public function testProviderShouldBeRegisteredAndTaskFailValidation(): void {
  401. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  402. new ServiceRegistration('test', BrokenSyncProvider::class)
  403. ]);
  404. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  405. self::assertTrue($this->manager->hasProviders());
  406. $task = new Task(TextToText::ID, ['wrongInputKey' => 'Hello'], 'test', null);
  407. self::assertNull($task->getId());
  408. self::expectException(ValidationException::class);
  409. $this->manager->scheduleTask($task);
  410. }
  411. public function testProviderShouldBeRegisteredAndTaskWithFilesFailValidation(): void {
  412. $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([
  413. new ServiceRegistration('test', AudioToImage::class)
  414. ]);
  415. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  416. new ServiceRegistration('test', AsyncProvider::class)
  417. ]);
  418. $user = $this->createMock(IUser::class);
  419. $user->expects($this->any())->method('getUID')->willReturn(null);
  420. $mount = $this->createMock(ICachedMountInfo::class);
  421. $mount->expects($this->any())->method('getUser')->willReturn($user);
  422. $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]);
  423. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  424. self::assertTrue($this->manager->hasProviders());
  425. $audioId = $this->getFile('audioInput', 'Hello')->getId();
  426. $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', null);
  427. self::assertNull($task->getId());
  428. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  429. self::expectException(UnauthorizedException::class);
  430. $this->manager->scheduleTask($task);
  431. }
  432. public function testProviderShouldBeRegisteredAndFail(): void {
  433. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  434. new ServiceRegistration('test', FailingSyncProvider::class)
  435. ]);
  436. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  437. self::assertTrue($this->manager->hasProviders());
  438. $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
  439. self::assertNull($task->getId());
  440. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  441. $this->manager->scheduleTask($task);
  442. self::assertNotNull($task->getId());
  443. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  444. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
  445. $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
  446. \OCP\Server::get(ITimeFactory::class),
  447. $this->manager,
  448. $this->jobList,
  449. \OCP\Server::get(LoggerInterface::class),
  450. );
  451. $backgroundJob->start($this->jobList);
  452. $task = $this->manager->getTask($task->getId());
  453. self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
  454. self::assertEquals(FailingSyncProvider::ERROR_MESSAGE, $task->getErrorMessage());
  455. }
  456. public function testProviderShouldBeRegisteredAndFailOutputValidation(): void {
  457. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  458. new ServiceRegistration('test', BrokenSyncProvider::class)
  459. ]);
  460. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  461. self::assertTrue($this->manager->hasProviders());
  462. $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
  463. self::assertNull($task->getId());
  464. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  465. $this->manager->scheduleTask($task);
  466. self::assertNotNull($task->getId());
  467. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  468. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
  469. $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
  470. \OCP\Server::get(ITimeFactory::class),
  471. $this->manager,
  472. $this->jobList,
  473. \OCP\Server::get(LoggerInterface::class),
  474. );
  475. $backgroundJob->start($this->jobList);
  476. $task = $this->manager->getTask($task->getId());
  477. self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
  478. self::assertEquals('The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec', $task->getErrorMessage());
  479. }
  480. public function testProviderShouldBeRegisteredAndRun(): void {
  481. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  482. new ServiceRegistration('test', SuccessfulSyncProvider::class)
  483. ]);
  484. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  485. $taskTypeStruct = $this->manager->getAvailableTaskTypes()[array_keys($this->manager->getAvailableTaskTypes())[0]];
  486. self::assertTrue(isset($taskTypeStruct['inputShape']['input']));
  487. self::assertEquals(EShapeType::Text, $taskTypeStruct['inputShape']['input']->getShapeType());
  488. self::assertTrue(isset($taskTypeStruct['optionalInputShape']['optionalKey']));
  489. self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalInputShape']['optionalKey']->getShapeType());
  490. self::assertTrue(isset($taskTypeStruct['outputShape']['output']));
  491. self::assertEquals(EShapeType::Text, $taskTypeStruct['outputShape']['output']->getShapeType());
  492. self::assertTrue(isset($taskTypeStruct['optionalOutputShape']['optionalKey']));
  493. self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalOutputShape']['optionalKey']->getShapeType());
  494. self::assertTrue($this->manager->hasProviders());
  495. $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
  496. self::assertNull($task->getId());
  497. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  498. $this->manager->scheduleTask($task);
  499. self::assertNotNull($task->getId());
  500. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  501. // Task object retrieved from db is up-to-date
  502. $task2 = $this->manager->getTask($task->getId());
  503. self::assertEquals($task->getId(), $task2->getId());
  504. self::assertEquals(['input' => 'Hello'], $task2->getInput());
  505. self::assertNull($task2->getOutput());
  506. self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());
  507. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  508. $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
  509. \OCP\Server::get(ITimeFactory::class),
  510. $this->manager,
  511. $this->jobList,
  512. \OCP\Server::get(LoggerInterface::class),
  513. );
  514. $backgroundJob->start($this->jobList);
  515. $task = $this->manager->getTask($task->getId());
  516. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus(), 'Status is '. $task->getStatus() . ' with error message: ' . $task->getErrorMessage());
  517. self::assertEquals(['output' => 'Hello'], $task->getOutput());
  518. self::assertEquals(1, $task->getProgress());
  519. }
  520. public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningRawFileData(): void {
  521. $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([
  522. new ServiceRegistration('test', AudioToImage::class)
  523. ]);
  524. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  525. new ServiceRegistration('test', AsyncProvider::class)
  526. ]);
  527. $user = $this->createMock(IUser::class);
  528. $user->expects($this->any())->method('getUID')->willReturn('testuser');
  529. $mount = $this->createMock(ICachedMountInfo::class);
  530. $mount->expects($this->any())->method('getUser')->willReturn($user);
  531. $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]);
  532. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  533. self::assertTrue($this->manager->hasProviders());
  534. $audioId = $this->getFile('audioInput', 'Hello')->getId();
  535. $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', 'testuser');
  536. self::assertNull($task->getId());
  537. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  538. $this->manager->scheduleTask($task);
  539. self::assertNotNull($task->getId());
  540. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  541. // Task object retrieved from db is up-to-date
  542. $task2 = $this->manager->getTask($task->getId());
  543. self::assertEquals($task->getId(), $task2->getId());
  544. self::assertEquals(['audio' => $audioId], $task2->getInput());
  545. self::assertNull($task2->getOutput());
  546. self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());
  547. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  548. $this->manager->setTaskProgress($task2->getId(), 0.1);
  549. $input = $this->manager->prepareInputData($task2);
  550. self::assertTrue(isset($input['audio']));
  551. self::assertInstanceOf(\OCP\Files\File::class, $input['audio']);
  552. self::assertEquals($audioId, $input['audio']->getId());
  553. $this->manager->setTaskResult($task2->getId(), null, ['spectrogram' => 'World']);
  554. $task = $this->manager->getTask($task->getId());
  555. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
  556. self::assertEquals(1, $task->getProgress());
  557. self::assertTrue(isset($task->getOutput()['spectrogram']));
  558. $node = $this->rootFolder->getFirstNodeByIdInPath($task->getOutput()['spectrogram'], '/' . $this->rootFolder->getAppDataDirectoryName() . '/');
  559. self::assertNotNull($node);
  560. self::assertInstanceOf(\OCP\Files\File::class, $node);
  561. self::assertEquals('World', $node->getContent());
  562. }
  563. public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningFileIds(): void {
  564. $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([
  565. new ServiceRegistration('test', AudioToImage::class)
  566. ]);
  567. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  568. new ServiceRegistration('test', AsyncProvider::class)
  569. ]);
  570. $user = $this->createMock(IUser::class);
  571. $user->expects($this->any())->method('getUID')->willReturn('testuser');
  572. $mount = $this->createMock(ICachedMountInfo::class);
  573. $mount->expects($this->any())->method('getUser')->willReturn($user);
  574. $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]);
  575. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  576. self::assertTrue($this->manager->hasProviders());
  577. $audioId = $this->getFile('audioInput', 'Hello')->getId();
  578. $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', 'testuser');
  579. self::assertNull($task->getId());
  580. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  581. $this->manager->scheduleTask($task);
  582. self::assertNotNull($task->getId());
  583. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  584. // Task object retrieved from db is up-to-date
  585. $task2 = $this->manager->getTask($task->getId());
  586. self::assertEquals($task->getId(), $task2->getId());
  587. self::assertEquals(['audio' => $audioId], $task2->getInput());
  588. self::assertNull($task2->getOutput());
  589. self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());
  590. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  591. $this->manager->setTaskProgress($task2->getId(), 0.1);
  592. $input = $this->manager->prepareInputData($task2);
  593. self::assertTrue(isset($input['audio']));
  594. self::assertInstanceOf(\OCP\Files\File::class, $input['audio']);
  595. self::assertEquals($audioId, $input['audio']->getId());
  596. $outputFileId = $this->getFile('audioOutput', 'World')->getId();
  597. $this->manager->setTaskResult($task2->getId(), null, ['spectrogram' => $outputFileId], true);
  598. $task = $this->manager->getTask($task->getId());
  599. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
  600. self::assertEquals(1, $task->getProgress());
  601. self::assertTrue(isset($task->getOutput()['spectrogram']));
  602. $node = $this->rootFolder->getFirstNodeById($task->getOutput()['spectrogram']);
  603. self::assertNotNull($node, 'fileId:' . $task->getOutput()['spectrogram']);
  604. self::assertInstanceOf(\OCP\Files\File::class, $node);
  605. self::assertEquals('World', $node->getContent());
  606. }
  607. public function testNonexistentTask(): void {
  608. $this->expectException(\OCP\TaskProcessing\Exception\NotFoundException::class);
  609. $this->manager->getTask(2147483646);
  610. }
  611. public function testOldTasksShouldBeCleanedUp(): void {
  612. $currentTime = new \DateTime('now');
  613. $timeFactory = $this->createMock(ITimeFactory::class);
  614. $timeFactory->expects($this->any())->method('getDateTime')->willReturnCallback(fn () => $currentTime);
  615. $timeFactory->expects($this->any())->method('getTime')->willReturnCallback(fn () => $currentTime->getTimestamp());
  616. $this->taskMapper = new TaskMapper(
  617. \OCP\Server::get(IDBConnection::class),
  618. $timeFactory,
  619. );
  620. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  621. new ServiceRegistration('test', SuccessfulSyncProvider::class)
  622. ]);
  623. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  624. self::assertTrue($this->manager->hasProviders());
  625. $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
  626. $this->manager->scheduleTask($task);
  627. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  628. $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
  629. \OCP\Server::get(ITimeFactory::class),
  630. $this->manager,
  631. $this->jobList,
  632. \OCP\Server::get(LoggerInterface::class),
  633. );
  634. $backgroundJob->start($this->jobList);
  635. $task = $this->manager->getTask($task->getId());
  636. $currentTime = $currentTime->add(new \DateInterval('P1Y'));
  637. // run background job
  638. $bgJob = new RemoveOldTasksBackgroundJob(
  639. $timeFactory,
  640. $this->taskMapper,
  641. \OC::$server->get(LoggerInterface::class),
  642. \OCP\Server::get(IAppDataFactory::class),
  643. );
  644. $bgJob->setArgument([]);
  645. $bgJob->start($this->jobList);
  646. $this->expectException(NotFoundException::class);
  647. $this->manager->getTask($task->getId());
  648. }
  649. public function testShouldTransparentlyHandleTextProcessingProviders(): void {
  650. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
  651. new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class)
  652. ]);
  653. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  654. ]);
  655. $taskTypes = $this->manager->getAvailableTaskTypes();
  656. self::assertCount(1, $taskTypes);
  657. self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
  658. self::assertTrue($this->manager->hasProviders());
  659. $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
  660. $this->manager->scheduleTask($task);
  661. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  662. $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
  663. \OCP\Server::get(ITimeFactory::class),
  664. $this->manager,
  665. $this->jobList,
  666. \OCP\Server::get(LoggerInterface::class),
  667. );
  668. $backgroundJob->start($this->jobList);
  669. $task = $this->manager->getTask($task->getId());
  670. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
  671. self::assertIsArray($task->getOutput());
  672. self::assertTrue(isset($task->getOutput()['output']));
  673. self::assertEquals('Hello Summarize', $task->getOutput()['output']);
  674. self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran);
  675. }
  676. public function testShouldTransparentlyHandleFailingTextProcessingProviders(): void {
  677. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
  678. new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class)
  679. ]);
  680. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  681. ]);
  682. $taskTypes = $this->manager->getAvailableTaskTypes();
  683. self::assertCount(1, $taskTypes);
  684. self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
  685. self::assertTrue($this->manager->hasProviders());
  686. $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
  687. $this->manager->scheduleTask($task);
  688. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
  689. $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
  690. \OCP\Server::get(ITimeFactory::class),
  691. $this->manager,
  692. $this->jobList,
  693. \OCP\Server::get(LoggerInterface::class),
  694. );
  695. $backgroundJob->start($this->jobList);
  696. $task = $this->manager->getTask($task->getId());
  697. self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
  698. self::assertTrue($task->getOutput() === null);
  699. self::assertEquals('ERROR', $task->getErrorMessage());
  700. self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran);
  701. }
  702. public function testShouldTransparentlyHandleText2ImageProviders(): void {
  703. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
  704. new ServiceRegistration('test', SuccessfulTextToImageProvider::class)
  705. ]);
  706. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  707. ]);
  708. $taskTypes = $this->manager->getAvailableTaskTypes();
  709. self::assertCount(1, $taskTypes);
  710. self::assertTrue(isset($taskTypes[TextToImage::ID]));
  711. self::assertTrue($this->manager->hasProviders());
  712. $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
  713. $this->manager->scheduleTask($task);
  714. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  715. $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
  716. \OCP\Server::get(ITimeFactory::class),
  717. $this->manager,
  718. $this->jobList,
  719. \OCP\Server::get(LoggerInterface::class),
  720. );
  721. $backgroundJob->start($this->jobList);
  722. $task = $this->manager->getTask($task->getId());
  723. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
  724. self::assertIsArray($task->getOutput());
  725. self::assertTrue(isset($task->getOutput()['images']));
  726. self::assertIsArray($task->getOutput()['images']);
  727. self::assertCount(3, $task->getOutput()['images']);
  728. self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran);
  729. $node = $this->rootFolder->getFirstNodeByIdInPath($task->getOutput()['images'][0], '/' . $this->rootFolder->getAppDataDirectoryName() . '/');
  730. self::assertNotNull($node);
  731. self::assertInstanceOf(\OCP\Files\File::class, $node);
  732. self::assertEquals('test', $node->getContent());
  733. }
  734. public function testShouldTransparentlyHandleFailingText2ImageProviders(): void {
  735. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
  736. new ServiceRegistration('test', FailingTextToImageProvider::class)
  737. ]);
  738. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  739. ]);
  740. $taskTypes = $this->manager->getAvailableTaskTypes();
  741. self::assertCount(1, $taskTypes);
  742. self::assertTrue(isset($taskTypes[TextToImage::ID]));
  743. self::assertTrue($this->manager->hasProviders());
  744. $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
  745. $this->manager->scheduleTask($task);
  746. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
  747. $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
  748. \OCP\Server::get(ITimeFactory::class),
  749. $this->manager,
  750. $this->jobList,
  751. \OCP\Server::get(LoggerInterface::class),
  752. );
  753. $backgroundJob->start($this->jobList);
  754. $task = $this->manager->getTask($task->getId());
  755. self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
  756. self::assertTrue($task->getOutput() === null);
  757. self::assertEquals('ERROR', $task->getErrorMessage());
  758. self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran);
  759. }
  760. }