* This file is licensed under the Affero General Public License version 3 or
 * later.
 * See the COPYING-README file.
 */

namespace Test\TextProcessing;

use OC\AppFramework\Bootstrap\Coordinator;
use OC\AppFramework\Bootstrap\RegistrationContext;
use OC\AppFramework\Bootstrap\ServiceRegistration;
use OC\EventDispatcher\EventDispatcher;
use OC\TextProcessing\Db\Task as DbTask;
use OC\TextProcessing\Db\TaskMapper;
use OC\TextProcessing\Manager;
use OC\TextProcessing\RemoveOldTasksBackgroundJob;
use OC\TextProcessing\TaskBackgroundJob;
use OCP\AppFramework\Db\DoesNotExistException;
use OCP\AppFramework\Utility\ITimeFactory;
use OCP\Common\Exception\NotFoundException;
use OCP\EventDispatcher\IEventDispatcher;
use OCP\IServerContainer;
use OCP\TextProcessing\Events\TaskFailedEvent;
use OCP\TextProcessing\Events\TaskSuccessfulEvent;
use OCP\TextProcessing\FreePromptTaskType;
use OCP\TextProcessing\IManager;
use OCP\TextProcessing\IProvider;
use OCP\TextProcessing\SummaryTaskType;
use OCP\PreConditionNotMetException;
use OCP\TextProcessing\Task;
use OCP\TextProcessing\TopicsTaskType;
use PHPUnit\Framework\Constraint\IsInstanceOf;
use Psr\Log\LoggerInterface;
use Test\BackgroundJob\DummyJobList;

/**
 * Test double: a summary provider whose process() succeeds.
 *
 * It appends ' Summarize' to the prompt and records in $ran that it was called.
 */
class SuccessfulSummaryProvider implements IProvider {
	/** Set to true as soon as process() has been invoked. */
	public bool $ran = false;

	public function getName(): string {
		return 'TEST Vanilla LLM Provider';
	}

	public function process(string $prompt): string {
		$this->ran = true;
		return $prompt . ' Summarize';
	}

	public function getTaskType(): string {
		return SummaryTaskType::class;
	}
}

/**
 * Test double: a summary provider whose process() always throws.
 *
 * $ran is still set before the exception is thrown, so tests can verify that
 * the provider was actually reached.
 */
class FailingSummaryProvider implements IProvider {
	/** Set to true as soon as process() has been invoked. */
	public bool $ran = false;

	public function getName(): string {
		return 'TEST Vanilla LLM Provider';
	}

	/**
	 * @throws \Exception always
	 */
	public function process(string $prompt): string {
		$this->ran = true;
		throw new \Exception('ERROR');
	}

	public function getTaskType(): string {
		return SummaryTaskType::class;
	}
}

/**
 * Test double: a free prompt provider whose process() succeeds.
 *
 * It appends ' Free Prompt' to the prompt and records in $ran that it was called.
 */
class FreePromptProvider implements IProvider {
	/** Set to true as soon as process() has been invoked. */
	public bool $ran = false;

	public function getName(): string {
		return 'TEST Free Prompt Provider';
	}

	public function process(string $prompt): string {
		$this->ran = true;
		return $prompt . ' Free Prompt';
	}

	public function getTaskType(): string {
		return FreePromptTaskType::class;
	}
}

/**
 * Tests for the text processing Manager and its two background jobs.
 *
 * The TaskMapper is mocked and backed by the in-memory array $tasksDb, and the
 * providers are served from $providers through a mocked server container.
 */
class TextProcessingTest extends \Test\TestCase {
	private IManager $manager;
	private Coordinator $coordinator;

	/** @var array<class-string<IProvider>, IProvider> provider instances handed out by the mocked container */
	private array $providers;

	/** @var IServerContainer&\PHPUnit\Framework\MockObject\MockObject */
	private IServerContainer $serverContainer;

	/** Real dispatcher after setUp(); individual tests replace it with a mock. */
	private IEventDispatcher $eventDispatcher;

	/** @var RegistrationContext&\PHPUnit\Framework\MockObject\MockObject */
	private RegistrationContext $registrationContext;

	/** Clock used by the TaskMapper mock; tests move it forward to age the stored tasks. */
	private \DateTimeImmutable $currentTime;

	/** @var TaskMapper&\PHPUnit\Framework\MockObject\MockObject */
	private TaskMapper $taskMapper;

	/** @var array<int, array> fake task table: DbTask::toRow() results keyed by task id */
	private array $tasksDb;

	/** @var DummyJobList&\PHPUnit\Framework\MockObject\MockObject */
	private DummyJobList $jobList;

	protected function setUp(): void {
		parent::setUp();

		$this->providers = [
			SuccessfulSummaryProvider::class => new SuccessfulSummaryProvider(),
			FailingSummaryProvider::class => new FailingSummaryProvider(),
			FreePromptProvider::class => new FreePromptProvider(),
		];

		// The container only ever resolves the provider classes registered above.
		$this->serverContainer = $this->createMock(IServerContainer::class);
		$this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) {
			return $this->providers[$class];
		});

		$this->eventDispatcher = new EventDispatcher(
			new \Symfony\Component\EventDispatcher\EventDispatcher(),
			$this->serverContainer,
			\OC::$server->get(LoggerInterface::class),
		);

		$this->registrationContext = $this->createMock(RegistrationContext::class);
		$this->coordinator = $this->createMock(Coordinator::class);
		$this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext);

		$this->currentTime = new \DateTimeImmutable('now');

		$this->taskMapper = $this->createMock(TaskMapper::class);
		$this->tasksDb = [];
		$this->taskMapper
			->expects($this->any())
			->method('insert')
			->willReturnCallback(function (DbTask $task) {
				// Hand out the next free id; reusing the current maximum would
				// overwrite the row that already owns it.
				$nextId = $this->tasksDb === [] ? 1 : max(array_keys($this->tasksDb)) + 1;
				$task->setId($nextId);
				$task->setLastUpdated($this->currentTime->getTimestamp());
				$this->tasksDb[$task->getId()] = $task->toRow();
				return $task;
			});

		$this->taskMapper
			->expects($this->any())
			->method('update')
			->willReturnCallback(function (DbTask $task) {
				$task->setLastUpdated($this->currentTime->getTimestamp());
				$this->tasksDb[$task->getId()] = $task->toRow();
				return $task;
			});

		$this->taskMapper
			->expects($this->any())
			->method('find')
			->willReturnCallback(function (int $id) {
				if (!isset($this->tasksDb[$id])) {
					throw new DoesNotExistException('Could not find it');
				}
				return DbTask::fromRow($this->tasksDb[$id]);
			});

		$this->taskMapper
			->expects($this->any())
			->method('deleteOlderThan')
			->willReturnCallback(function (int $timeout) {
				// Keep only rows updated within the last $timeout seconds of the fake clock.
				$this->tasksDb = array_filter($this->tasksDb, function (array $task) use ($timeout) {
					return $task['last_updated'] >= $this->currentTime->getTimestamp() - $timeout;
				});
			});

		// Only add() is stubbed (as a no-op); every other method keeps DummyJobList's behaviour.
		$this->jobList = $this->createPartialMock(DummyJobList::class, ['add']);
		$this->jobList->expects($this->any())->method('add')->willReturnCallback(function () {
		});

		$this->manager = new Manager(
			$this->serverContainer,
			$this->coordinator,
			\OC::$server->get(LoggerInterface::class),
			$this->jobList,
			$this->taskMapper,
		);
	}

	/**
	 * Without registered providers there are no task types and running a task is refused.
	 */
	public function testShouldNotHaveAnyProviders(): void {
		$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]);
		$this->assertCount(0, $this->manager->getAvailableTaskTypes());
		$this->assertFalse($this->manager->hasProviders());
		$this->expectException(PreConditionNotMetException::class);
		$this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null));
	}

	/**
	 * A single summary provider handles summary tasks and rejects free prompt tasks.
	 */
	public function testProviderShouldBeRegisteredAndRun(): void {
		$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
			new ServiceRegistration('test', SuccessfulSummaryProvider::class)
		]);
		$this->assertCount(1, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());
		$this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null)));

		// Free prompts are not implemented by the registered provider, only summaries
		$this->expectException(PreConditionNotMetException::class);
		$this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null));
	}

	/**
	 * A scheduled task is persisted, executed by TaskBackgroundJob and ends up successful.
	 */
	public function testProviderShouldBeRegisteredAndScheduled(): void {
		// register provider
		$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
			new ServiceRegistration('test', SuccessfulSummaryProvider::class)
		]);
		$this->assertCount(1, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());

		// create task object
		$task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
		$this->assertNull($task->getId());
		$this->assertNull($task->getOutput());

		// schedule works
		$this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
		$this->manager->scheduleTask($task);

		// Task object is up-to-date
		$this->assertNotNull($task->getId());
		$this->assertNull($task->getOutput());
		$this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());

		// Task object retrieved from db is up-to-date
		$task2 = $this->manager->getTask($task->getId());
		$this->assertEquals($task->getId(), $task2->getId());
		$this->assertEquals('Hello', $task2->getInput());
		$this->assertNull($task2->getOutput());
		$this->assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());

		// exactly one TaskSuccessfulEvent must be dispatched by the job
		$this->eventDispatcher = $this->createMock(IEventDispatcher::class);
		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));

		// run background job
		$bgJob = new TaskBackgroundJob(
			\OC::$server->get(ITimeFactory::class),
			$this->manager,
			$this->eventDispatcher,
		);
		$bgJob->setArgument(['taskId' => $task->getId()]);
		$bgJob->start($this->jobList);
		$provider = $this->providers[SuccessfulSummaryProvider::class];
		$this->assertTrue($provider->ran);

		// Task object retrieved from db is up-to-date
		$task3 = $this->manager->getTask($task->getId());
		$this->assertEquals($task->getId(), $task3->getId());
		$this->assertEquals('Hello', $task3->getInput());
		$this->assertEquals('Hello Summarize', $task3->getOutput());
		$this->assertEquals(Task::STATUS_SUCCESSFUL, $task3->getStatus());
	}

	/**
	 * With two providers registered, each task type is routed to the matching provider.
	 */
	public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly(): void {
		$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
			new ServiceRegistration('test', SuccessfulSummaryProvider::class),
			new ServiceRegistration('test', FreePromptProvider::class),
		]);
		$this->assertCount(2, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());

		// Try free prompt task
		$this->assertEquals('Hello Free Prompt', $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null)));

		// Try summary task
		$this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null)));

		// Topics are implemented by neither of the registered providers
		$this->expectException(PreConditionNotMetException::class);
		$this->manager->runTask(new Task(TopicsTaskType::class, 'Hello', 'test', null));
	}

	/**
	 * Looking up an id that was never stored raises NotFoundException.
	 */
	public function testNonexistentTask(): void {
		$this->expectException(NotFoundException::class);
		$this->manager->getTask(98765432456);
	}

	/**
	 * A scheduled task whose provider throws ends up failed, without output.
	 */
	public function testTaskFailure(): void {
		// register provider
		$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
			new ServiceRegistration('test', FailingSummaryProvider::class),
		]);
		$this->assertCount(1, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());

		// create task object
		$task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
		$this->assertNull($task->getId());
		$this->assertNull($task->getOutput());

		// schedule works
		$this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
		$this->manager->scheduleTask($task);

		// Task object is up-to-date
		$this->assertNotNull($task->getId());
		$this->assertNull($task->getOutput());
		$this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());

		// Task object retrieved from db is up-to-date
		$task2 = $this->manager->getTask($task->getId());
		$this->assertEquals($task->getId(), $task2->getId());
		$this->assertEquals('Hello', $task2->getInput());
		$this->assertNull($task2->getOutput());
		$this->assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());

		// exactly one TaskFailedEvent must be dispatched by the job
		$this->eventDispatcher = $this->createMock(IEventDispatcher::class);
		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));

		// run background job
		$bgJob = new TaskBackgroundJob(
			\OC::$server->get(ITimeFactory::class),
			$this->manager,
			$this->eventDispatcher,
		);
		$bgJob->setArgument(['taskId' => $task->getId()]);
		$bgJob->start($this->jobList);
		$provider = $this->providers[FailingSummaryProvider::class];
		$this->assertTrue($provider->ran);

		// Task object retrieved from db is up-to-date
		$task3 = $this->manager->getTask($task->getId());
		$this->assertEquals($task->getId(), $task3->getId());
		$this->assertEquals('Hello', $task3->getInput());
		$this->assertNull($task3->getOutput());
		$this->assertEquals(Task::STATUS_FAILED, $task3->getStatus());
	}

	/**
	 * After the fake clock has moved on by a year, the cleanup job removes the stored task.
	 */
	public function testOldTasksShouldBeCleanedUp(): void {
		$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
			new ServiceRegistration('test', SuccessfulSummaryProvider::class)
		]);
		$this->assertCount(1, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());
		$task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
		$this->assertEquals('Hello Summarize', $this->manager->runTask($task));

		// Age the stored task: the mapper mock compares last_updated against this clock.
		$this->currentTime = $this->currentTime->add(new \DateInterval('P1Y'));

		// run background job
		$bgJob = new RemoveOldTasksBackgroundJob(
			\OC::$server->get(ITimeFactory::class),
			$this->taskMapper,
			\OC::$server->get(LoggerInterface::class),
		);
		$bgJob->setArgument([]);
		$bgJob->start($this->jobList);

		$this->expectException(NotFoundException::class);
		$this->manager->getTask($task->getId());
	}
}