TextProcessingTest.php 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. <?php
  2. /**
  3. * SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
  4. * SPDX-License-Identifier: AGPL-3.0-or-later
  5. */
  6. namespace Test\TextProcessing;
  7. use OC\AppFramework\Bootstrap\Coordinator;
  8. use OC\AppFramework\Bootstrap\RegistrationContext;
  9. use OC\AppFramework\Bootstrap\ServiceRegistration;
  10. use OC\EventDispatcher\EventDispatcher;
  11. use OC\TextProcessing\Db\Task as DbTask;
  12. use OC\TextProcessing\Db\TaskMapper;
  13. use OC\TextProcessing\Manager;
  14. use OC\TextProcessing\RemoveOldTasksBackgroundJob;
  15. use OC\TextProcessing\TaskBackgroundJob;
  16. use OCP\AppFramework\Db\DoesNotExistException;
  17. use OCP\AppFramework\Utility\ITimeFactory;
  18. use OCP\BackgroundJob\IJobList;
  19. use OCP\Common\Exception\NotFoundException;
  20. use OCP\EventDispatcher\IEventDispatcher;
  21. use OCP\IConfig;
  22. use OCP\IServerContainer;
  23. use OCP\PreConditionNotMetException;
  24. use OCP\TextProcessing\Events\TaskFailedEvent;
  25. use OCP\TextProcessing\Events\TaskSuccessfulEvent;
  26. use OCP\TextProcessing\FreePromptTaskType;
  27. use OCP\TextProcessing\IManager;
  28. use OCP\TextProcessing\IProvider;
  29. use OCP\TextProcessing\SummaryTaskType;
  30. use OCP\TextProcessing\Task;
  31. use OCP\TextProcessing\TopicsTaskType;
  32. use PHPUnit\Framework\Constraint\IsInstanceOf;
  33. use Psr\Log\LoggerInterface;
  34. use Test\BackgroundJob\DummyJobList;
  /**
   * Summary provider stub whose processing always succeeds: it echoes the
   * prompt followed by ' Summarize'.
   */
  class SuccessfulSummaryProvider implements IProvider {
      /** Flipped to true the first time process() is invoked. */
      public bool $ran = false;

      public function getTaskType(): string {
          return SummaryTaskType::class;
      }

      public function getName(): string {
          return 'TEST Vanilla LLM Provider';
      }

      public function process(string $prompt): string {
          $this->ran = true;
          return "{$prompt} Summarize";
      }
  }
  /**
   * Summary provider stub whose processing always fails: process() records
   * that it was called and then throws a generic \Exception('ERROR').
   */
  class FailingSummaryProvider implements IProvider {
      /** Flipped to true the first time process() is invoked (before it throws). */
      public bool $ran = false;

      public function getTaskType(): string {
          return SummaryTaskType::class;
      }

      public function getName(): string {
          return 'TEST Vanilla LLM Provider';
      }

      /**
       * @throws \Exception always
       */
      public function process(string $prompt): string {
          $this->ran = true;
          $failure = new \Exception('ERROR');
          throw $failure;
      }
  }
  /**
   * Free-prompt provider stub whose processing always succeeds: it echoes the
   * prompt followed by ' Free Prompt'.
   */
  class FreePromptProvider implements IProvider {
      /** Flipped to true the first time process() is invoked. */
      public bool $ran = false;

      public function getTaskType(): string {
          return FreePromptTaskType::class;
      }

      public function getName(): string {
          return 'TEST Free Prompt Provider';
      }

      public function process(string $prompt): string {
          $this->ran = true;
          return "{$prompt} Free Prompt";
      }
  }
  /**
   * Tests for OC\TextProcessing\Manager.
   *
   * The TaskMapper is replaced by a mock that stores task rows in the in-memory
   * $tasksDb array, the registration context is mocked per test, and the server
   * container hands out the stub provider instances held in $providers.
   *
   * @group DB
   */
  class TextProcessingTest extends \Test\TestCase {
      private IManager $manager;
      private Coordinator $coordinator;
      /** @var array<string, IProvider> provider instances keyed by class name, served by the mocked container */
      private array $providers;
      private IServerContainer $serverContainer;
      private IEventDispatcher $eventDispatcher;
      private RegistrationContext $registrationContext;
      /** Fake clock used by the mocked mapper for last_updated stamps and age-based cleanup. */
      private \DateTimeImmutable $currentTime;
      private TaskMapper $taskMapper;
      /** @var array<int, array<string, mixed>> in-memory task "table": rows from DbTask::toRow() keyed by task id */
      private array $tasksDb;
      private IJobList $jobList;

      protected function setUp(): void {
          parent::setUp();

          $this->providers = [
              SuccessfulSummaryProvider::class => new SuccessfulSummaryProvider(),
              FailingSummaryProvider::class => new FailingSummaryProvider(),
              FreePromptProvider::class => new FreePromptProvider(),
          ];

          $this->serverContainer = $this->createMock(IServerContainer::class);
          $this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) {
              return $this->providers[$class];
          });

          $this->eventDispatcher = new EventDispatcher(
              new \Symfony\Component\EventDispatcher\EventDispatcher(),
              $this->serverContainer,
              \OC::$server->get(LoggerInterface::class),
          );

          $this->registrationContext = $this->createMock(RegistrationContext::class);
          $this->coordinator = $this->createMock(Coordinator::class);
          $this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext);

          $this->currentTime = new \DateTimeImmutable('now');
          $this->tasksDb = [];
          $this->taskMapper = $this->createMock(TaskMapper::class);
          $this->taskMapper
              ->expects($this->any())
              ->method('insert')
              ->willReturnCallback(function (DbTask $task) {
                  // Emulate an auto-increment key: one past the highest stored id.
                  // (Using the max id itself would overwrite the latest stored row.)
                  $task->setId($this->tasksDb === [] ? 1 : max(array_keys($this->tasksDb)) + 1);
                  $task->setLastUpdated($this->currentTime->getTimestamp());
                  $this->tasksDb[$task->getId()] = $task->toRow();
                  return $task;
              });
          $this->taskMapper
              ->expects($this->any())
              ->method('update')
              ->willReturnCallback(function (DbTask $task) {
                  $task->setLastUpdated($this->currentTime->getTimestamp());
                  $this->tasksDb[$task->getId()] = $task->toRow();
                  return $task;
              });
          $this->taskMapper
              ->expects($this->any())
              ->method('find')
              ->willReturnCallback(function (int $id) {
                  if (!isset($this->tasksDb[$id])) {
                      throw new DoesNotExistException('Could not find it');
                  }
                  return DbTask::fromRow($this->tasksDb[$id]);
              });
          $this->taskMapper
              ->expects($this->any())
              ->method('deleteOlderThan')
              ->willReturnCallback(function (int $timeout) {
                  // Keep rows updated at or after the cutoff, relative to the fake clock.
                  $cutoff = $this->currentTime->getTimestamp() - $timeout;
                  $countBefore = count($this->tasksDb);
                  $this->tasksDb = array_filter(
                      $this->tasksDb,
                      static fn (array $task): bool => $task['last_updated'] >= $cutoff,
                  );
                  // NOTE(review): returns the number of removed rows, presumably mirroring
                  // TaskMapper::deleteOlderThan; callers that ignore the result are unaffected.
                  return $countBefore - count($this->tasksDb);
              });

          // Adding jobs is a no-op; tests start background jobs explicitly.
          $this->jobList = $this->createPartialMock(DummyJobList::class, ['add']);
          $this->jobList->expects($this->any())->method('add')->willReturnCallback(function () {
          });

          $config = $this->createMock(IConfig::class);
          $config->method('getAppValue')
              ->with('core', 'ai.textprocessing_provider_preferences', '')
              ->willReturn('');

          $this->manager = new Manager(
              $this->serverContainer,
              $this->coordinator,
              \OC::$server->get(LoggerInterface::class),
              $this->jobList,
              $this->taskMapper,
              $config,
              \OC::$server->get(\OCP\TaskProcessing\IManager::class),
          );
      }

      /**
       * Makes the mocked registration context report the given provider classes,
       * each registered under the app id 'test'. With no arguments, no providers
       * are registered.
       *
       * @param string ...$providerClasses provider class names (keys of $this->providers)
       */
      private function registerProviders(string ...$providerClasses): void {
          $this->registrationContext->expects($this->any())
              ->method('getTextProcessingProviders')
              ->willReturn(array_map(
                  static fn (string $providerClass): ServiceRegistration => new ServiceRegistration('test', $providerClass),
                  $providerClasses,
              ));
      }

      /**
       * Schedules a new summary task with input 'Hello'.
       *
       * Asserts that the task starts out unknown with no id or output, that the
       * object is updated in place to scheduled with an id, and that the copy
       * loaded back through the manager matches.
       *
       * @return Task the scheduled task
       */
      private function scheduleSummaryTask(): Task {
          $task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
          $this->assertNull($task->getId());
          $this->assertNull($task->getOutput());
          $this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());

          $this->manager->scheduleTask($task);

          // The task object passed to scheduleTask() is updated in place
          $this->assertNotNull($task->getId());
          $this->assertNull($task->getOutput());
          $this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());

          // The stored copy reflects the same state
          $stored = $this->manager->getTask($task->getId());
          $this->assertEquals($task->getId(), $stored->getId());
          $this->assertEquals('Hello', $stored->getInput());
          $this->assertNull($stored->getOutput());
          $this->assertEquals(Task::STATUS_SCHEDULED, $stored->getStatus());

          return $task;
      }

      /**
       * Runs TaskBackgroundJob for $task, using a mocked event dispatcher that
       * expects exactly one dispatchTyped() call with an instance of $expectedEventClass.
       */
      private function runTaskBackgroundJob(Task $task, string $expectedEventClass): void {
          $this->eventDispatcher = $this->createMock(IEventDispatcher::class);
          $this->eventDispatcher->expects($this->once())
              ->method('dispatchTyped')
              ->with(new IsInstanceOf($expectedEventClass));

          $bgJob = new TaskBackgroundJob(
              \OC::$server->get(ITimeFactory::class),
              $this->manager,
              $this->eventDispatcher,
          );
          $bgJob->setArgument(['taskId' => $task->getId()]);
          $bgJob->start($this->jobList);
      }

      public function testShouldNotHaveAnyProviders(): void {
          $this->registerProviders();
          $this->assertCount(0, $this->manager->getAvailableTaskTypes());
          $this->assertFalse($this->manager->hasProviders());
          $this->expectException(PreConditionNotMetException::class);
          $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null));
      }

      public function testProviderShouldBeRegisteredAndRun(): void {
          $this->registerProviders(SuccessfulSummaryProvider::class);
          $this->assertCount(1, $this->manager->getAvailableTaskTypes());
          $this->assertTrue($this->manager->hasProviders());
          $this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null)));

          // Only a summary provider is registered, so free prompt tasks cannot run
          $this->expectException(PreConditionNotMetException::class);
          $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null));
      }

      public function testProviderShouldBeRegisteredAndScheduled(): void {
          $this->registerProviders(SuccessfulSummaryProvider::class);
          $this->assertCount(1, $this->manager->getAvailableTaskTypes());
          $this->assertTrue($this->manager->hasProviders());

          $task = $this->scheduleSummaryTask();
          $this->runTaskBackgroundJob($task, TaskSuccessfulEvent::class);

          $this->assertTrue($this->providers[SuccessfulSummaryProvider::class]->ran);

          // The stored task now carries the provider's output
          $finished = $this->manager->getTask($task->getId());
          $this->assertEquals($task->getId(), $finished->getId());
          $this->assertEquals('Hello', $finished->getInput());
          $this->assertEquals('Hello Summarize', $finished->getOutput());
          $this->assertEquals(Task::STATUS_SUCCESSFUL, $finished->getStatus());
      }

      public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly(): void {
          $this->registerProviders(SuccessfulSummaryProvider::class, FreePromptProvider::class);
          $this->assertCount(2, $this->manager->getAvailableTaskTypes());
          $this->assertTrue($this->manager->hasProviders());

          // Free prompt tasks go to FreePromptProvider
          $this->assertEquals('Hello Free Prompt', $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null)));
          // Summary tasks go to SuccessfulSummaryProvider
          $this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null)));

          // Neither registered provider implements topics
          $this->expectException(PreConditionNotMetException::class);
          $this->manager->runTask(new Task(TopicsTaskType::class, 'Hello', 'test', null));
      }

      public function testNonexistentTask(): void {
          $this->expectException(NotFoundException::class);
          $this->manager->getTask(2147483646);
      }

      public function testTaskFailure(): void {
          $this->registerProviders(FailingSummaryProvider::class);
          $this->assertCount(1, $this->manager->getAvailableTaskTypes());
          $this->assertTrue($this->manager->hasProviders());

          $task = $this->scheduleSummaryTask();
          $this->runTaskBackgroundJob($task, TaskFailedEvent::class);

          $this->assertTrue($this->providers[FailingSummaryProvider::class]->ran);

          // The stored task is marked failed and has no output
          $finished = $this->manager->getTask($task->getId());
          $this->assertEquals($task->getId(), $finished->getId());
          $this->assertEquals('Hello', $finished->getInput());
          $this->assertNull($finished->getOutput());
          $this->assertEquals(Task::STATUS_FAILED, $finished->getStatus());
      }

      public function testOldTasksShouldBeCleanedUp(): void {
          $this->registerProviders(SuccessfulSummaryProvider::class);
          $this->assertCount(1, $this->manager->getAvailableTaskTypes());
          $this->assertTrue($this->manager->hasProviders());

          $task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
          $this->assertEquals('Hello Summarize', $this->manager->runTask($task));

          // Advance the fake clock by a year so the stored task counts as old
          $this->currentTime = $this->currentTime->add(new \DateInterval('P1Y'));

          $bgJob = new RemoveOldTasksBackgroundJob(
              \OC::$server->get(ITimeFactory::class),
              $this->taskMapper,
              \OC::$server->get(LoggerInterface::class),
          );
          $bgJob->setArgument([]);
          $bgJob->start($this->jobList);

          $this->expectException(NotFoundException::class);
          $this->manager->getTask($task->getId());
      }
  }