123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- <?php
- /**
- * Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net>
- * This file is licensed under the Affero General Public License version 3 or
- * later.
- * See the COPYING-README file.
- */
- namespace Test\TextProcessing;
- use OC\AppFramework\Bootstrap\Coordinator;
- use OC\AppFramework\Bootstrap\RegistrationContext;
- use OC\AppFramework\Bootstrap\ServiceRegistration;
- use OC\EventDispatcher\EventDispatcher;
- use OC\TextProcessing\Db\Task as DbTask;
- use OC\TextProcessing\Db\TaskMapper;
- use OC\TextProcessing\Manager;
- use OC\TextProcessing\RemoveOldTasksBackgroundJob;
- use OC\TextProcessing\TaskBackgroundJob;
- use OCP\AppFramework\Db\DoesNotExistException;
- use OCP\AppFramework\Utility\ITimeFactory;
- use OCP\Common\Exception\NotFoundException;
- use OCP\EventDispatcher\IEventDispatcher;
- use OCP\IServerContainer;
- use OCP\TextProcessing\Events\TaskFailedEvent;
- use OCP\TextProcessing\Events\TaskSuccessfulEvent;
- use OCP\TextProcessing\FreePromptTaskType;
- use OCP\TextProcessing\IManager;
- use OCP\TextProcessing\IProvider;
- use OCP\TextProcessing\SummaryTaskType;
- use OCP\PreConditionNotMetException;
- use OCP\TextProcessing\Task;
- use OCP\TextProcessing\TopicsTaskType;
- use PHPUnit\Framework\Constraint\IsInstanceOf;
- use Psr\Log\LoggerInterface;
- use Test\BackgroundJob\DummyJobList;
/**
 * Summary provider stub that always succeeds.
 *
 * Its output is the prompt followed by " Summarize", so tests can check exactly
 * which provider handled a task.
 */
class SuccessfulSummaryProvider implements IProvider {
	/** Flipped to true the first time process() is invoked. */
	public bool $ran = false;

	public function getTaskType(): string {
		return SummaryTaskType::class;
	}

	public function getName(): string {
		return 'TEST Vanilla LLM Provider';
	}

	/**
	 * Records the invocation and echoes the prompt with a fixed suffix.
	 *
	 * @param string $prompt task input
	 * @return string "$prompt Summarize"
	 */
	public function process(string $prompt): string {
		$this->ran = true;
		return sprintf('%s Summarize', $prompt);
	}
}
/**
 * Summary provider stub whose process() always throws.
 *
 * Tests use it to drive the task failure path.
 */
class FailingSummaryProvider implements IProvider {
	/** Flipped to true the first time process() is invoked. */
	public bool $ran = false;

	public function getTaskType(): string {
		return SummaryTaskType::class;
	}

	public function getName(): string {
		return 'TEST Vanilla LLM Provider';
	}

	/**
	 * Records the invocation, then fails unconditionally.
	 *
	 * @param string $prompt task input (ignored)
	 * @throws \Exception always, with message "ERROR"
	 */
	public function process(string $prompt): string {
		$this->ran = true;
		$failure = new \Exception('ERROR');
		throw $failure;
	}
}
/**
 * Free-prompt provider stub that always succeeds.
 *
 * Its output is the prompt followed by " Free Prompt".
 */
class FreePromptProvider implements IProvider {
	/** Flipped to true the first time process() is invoked. */
	public bool $ran = false;

	public function getTaskType(): string {
		return FreePromptTaskType::class;
	}

	public function getName(): string {
		return 'TEST Free Prompt Provider';
	}

	/**
	 * Records the invocation and echoes the prompt with a fixed suffix.
	 *
	 * @param string $prompt task input
	 * @return string "$prompt Free Prompt"
	 */
	public function process(string $prompt): string {
		$this->ran = true;
		$output = "{$prompt} Free Prompt";
		return $output;
	}
}
/**
 * Tests for the text processing Manager, TaskBackgroundJob and
 * RemoveOldTasksBackgroundJob.
 *
 * Providers come from an in-memory map served by a mocked server container.
 * The task mapper is a mock backed by the in-memory `$tasksDb` array, with
 * `$currentTime` as its clock, so no real database is touched.
 */
class TextProcessingTest extends \Test\TestCase {
	private IManager $manager;
	private Coordinator $coordinator;

	/** @var array<string, IProvider> provider instances handed out by the mocked container, keyed by class name */
	private array $providers;
	/** @var IServerContainer|\PHPUnit\Framework\MockObject\MockObject */
	private IServerContainer $serverContainer;
	private IEventDispatcher $eventDispatcher;
	/** @var RegistrationContext|\PHPUnit\Framework\MockObject\MockObject */
	private RegistrationContext $registrationContext;
	/** Clock used by the fake task mapper for `last_updated` and for age-based cleanup. */
	private \DateTimeImmutable $currentTime;
	/** @var TaskMapper|\PHPUnit\Framework\MockObject\MockObject */
	private TaskMapper $taskMapper;
	/** @var array<int, array> fake task table: task id => row produced by DbTask::toRow() */
	private array $tasksDb;
	/** @var DummyJobList|\PHPUnit\Framework\MockObject\MockObject */
	private DummyJobList $jobList;

	protected function setUp(): void {
		parent::setUp();

		$this->providers = [
			SuccessfulSummaryProvider::class => new SuccessfulSummaryProvider(),
			FailingSummaryProvider::class => new FailingSummaryProvider(),
			FreePromptProvider::class => new FreePromptProvider(),
		];

		$this->serverContainer = $this->createMock(IServerContainer::class);
		$this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) {
			return $this->providers[$class];
		});

		$this->eventDispatcher = new EventDispatcher(
			new \Symfony\Component\EventDispatcher\EventDispatcher(),
			$this->serverContainer,
			\OC::$server->get(LoggerInterface::class),
		);

		$this->registrationContext = $this->createMock(RegistrationContext::class);
		$this->coordinator = $this->createMock(Coordinator::class);
		$this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext);

		$this->currentTime = new \DateTimeImmutable('now');

		$this->taskMapper = $this->createMock(TaskMapper::class);
		$this->tasksDb = [];
		$this->taskMapper
			->expects($this->any())
			->method('insert')
			->willReturnCallback(function (DbTask $task) {
				// Auto-increment ids: 1 for the first row, max(id) + 1 afterwards,
				// so consecutive inserts never overwrite an existing row.
				$task->setId($this->tasksDb === [] ? 1 : max(array_keys($this->tasksDb)) + 1);
				$task->setLastUpdated($this->currentTime->getTimestamp());
				$this->tasksDb[$task->getId()] = $task->toRow();
				return $task;
			});
		$this->taskMapper
			->expects($this->any())
			->method('update')
			->willReturnCallback(function (DbTask $task) {
				$task->setLastUpdated($this->currentTime->getTimestamp());
				$this->tasksDb[$task->getId()] = $task->toRow();
				return $task;
			});
		$this->taskMapper
			->expects($this->any())
			->method('find')
			->willReturnCallback(function (int $id) {
				if (!isset($this->tasksDb[$id])) {
					throw new DoesNotExistException('Could not find it');
				}
				return DbTask::fromRow($this->tasksDb[$id]);
			});
		$this->taskMapper
			->expects($this->any())
			->method('deleteOlderThan')
			->willReturnCallback(function (int $timeout) {
				// Keep rows updated within the last $timeout seconds relative to the fake clock.
				$countBefore = count($this->tasksDb);
				$this->tasksDb = array_filter($this->tasksDb, function (array $task) use ($timeout) {
					return $task['last_updated'] >= $this->currentTime->getTimestamp() - $timeout;
				});
				// Report the number of removed rows, like a real DELETE would.
				return $countBefore - count($this->tasksDb);
			});

		// Only 'add' is stubbed out, so scheduling a task does not queue a real job.
		// The tests run the background job by hand instead.
		$this->jobList = $this->createPartialMock(DummyJobList::class, ['add']);
		$this->jobList->expects($this->any())->method('add')->willReturnCallback(function () {
		});

		$this->manager = new Manager(
			$this->serverContainer,
			$this->coordinator,
			\OC::$server->get(LoggerInterface::class),
			$this->jobList,
			$this->taskMapper,
		);
	}

	/**
	 * Makes the mocked registration context report the given provider classes, all registered by app "test".
	 *
	 * @param string ...$providerClasses provider class names (keys of $this->providers)
	 */
	private function registerProviders(string ...$providerClasses): void {
		$registrations = array_map(static function (string $providerClass) {
			return new ServiceRegistration('test', $providerClass);
		}, $providerClasses);
		$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn($registrations);
	}

	/**
	 * Schedules a fresh task and asserts its state afterwards.
	 *
	 * Both the passed object and the copy read back through the manager must
	 * be scheduled and have no output yet.
	 */
	private function scheduleTaskAndAssertPersisted(Task $task): void {
		$this->assertNull($task->getId());
		$this->assertNull($task->getOutput());
		$this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());

		$this->manager->scheduleTask($task);

		// Task object is up-to-date
		$this->assertNotNull($task->getId());
		$this->assertNull($task->getOutput());
		$this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());

		// Task object retrieved from db is up-to-date
		$stored = $this->manager->getTask($task->getId());
		$this->assertEquals($task->getId(), $stored->getId());
		$this->assertEquals($task->getInput(), $stored->getInput());
		$this->assertNull($stored->getOutput());
		$this->assertEquals(Task::STATUS_SCHEDULED, $stored->getStatus());
	}

	/**
	 * Runs TaskBackgroundJob for $task and expects exactly one dispatched event of the given class.
	 *
	 * @param string $expectedEventClass event class the job must dispatch once
	 */
	private function runTaskBackgroundJob(Task $task, string $expectedEventClass): void {
		$eventDispatcher = $this->createMock(IEventDispatcher::class);
		$eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf($expectedEventClass));

		$bgJob = new TaskBackgroundJob(
			\OC::$server->get(ITimeFactory::class),
			$this->manager,
			$eventDispatcher,
		);
		$bgJob->setArgument(['taskId' => $task->getId()]);
		$bgJob->start($this->jobList);
	}

	public function testShouldNotHaveAnyProviders(): void {
		$this->registerProviders();
		$this->assertCount(0, $this->manager->getAvailableTaskTypes());
		$this->assertFalse($this->manager->hasProviders());
		$this->expectException(PreConditionNotMetException::class);
		$this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null));
	}

	public function testProviderShouldBeRegisteredAndRun(): void {
		$this->registerProviders(SuccessfulSummaryProvider::class);
		$this->assertCount(1, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());
		$this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null)));

		// Only a summary provider is registered, so nothing can handle a free prompt
		$this->expectException(PreConditionNotMetException::class);
		$this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null));
	}

	public function testProviderShouldBeRegisteredAndScheduled(): void {
		$this->registerProviders(SuccessfulSummaryProvider::class);
		$this->assertCount(1, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());

		$task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
		$this->scheduleTaskAndAssertPersisted($task);

		$this->runTaskBackgroundJob($task, TaskSuccessfulEvent::class);
		$this->assertTrue($this->providers[SuccessfulSummaryProvider::class]->ran);

		// Task object retrieved from db reflects the successful run
		$stored = $this->manager->getTask($task->getId());
		$this->assertEquals($task->getId(), $stored->getId());
		$this->assertEquals('Hello', $stored->getInput());
		$this->assertEquals('Hello Summarize', $stored->getOutput());
		$this->assertEquals(Task::STATUS_SUCCESSFUL, $stored->getStatus());
	}

	public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly(): void {
		$this->registerProviders(SuccessfulSummaryProvider::class, FreePromptProvider::class);
		$this->assertCount(2, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());

		// Free prompt is handled by FreePromptProvider
		$this->assertEquals('Hello Free Prompt', $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null)));
		// Summary is handled by SuccessfulSummaryProvider
		$this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null)));

		// Neither registered provider handles topic extraction
		$this->expectException(PreConditionNotMetException::class);
		$this->manager->runTask(new Task(TopicsTaskType::class, 'Hello', 'test', null));
	}

	public function testNonexistentTask(): void {
		$this->expectException(NotFoundException::class);
		$this->manager->getTask(98765432456);
	}

	public function testTaskFailure(): void {
		$this->registerProviders(FailingSummaryProvider::class);
		$this->assertCount(1, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());

		$task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
		$this->scheduleTaskAndAssertPersisted($task);

		$this->runTaskBackgroundJob($task, TaskFailedEvent::class);
		$this->assertTrue($this->providers[FailingSummaryProvider::class]->ran);

		// Task object retrieved from db reflects the failed run
		$stored = $this->manager->getTask($task->getId());
		$this->assertEquals($task->getId(), $stored->getId());
		$this->assertEquals('Hello', $stored->getInput());
		$this->assertNull($stored->getOutput());
		$this->assertEquals(Task::STATUS_FAILED, $stored->getStatus());
	}

	public function testOldTasksShouldBeCleanedUp(): void {
		$this->registerProviders(SuccessfulSummaryProvider::class);
		$this->assertCount(1, $this->manager->getAvailableTaskTypes());
		$this->assertTrue($this->manager->hasProviders());

		$task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
		$this->assertEquals('Hello Summarize', $this->manager->runTask($task));

		// Advance the fake clock by a year.
		// This assumes a year exceeds the cleanup job's retention period — TODO confirm.
		$this->currentTime = $this->currentTime->add(new \DateInterval('P1Y'));

		$bgJob = new RemoveOldTasksBackgroundJob(
			\OC::$server->get(ITimeFactory::class),
			$this->taskMapper,
			\OC::$server->get(LoggerInterface::class),
		);
		$bgJob->setArgument([]);
		$bgJob->start($this->jobList);

		$this->expectException(NotFoundException::class);
		$this->manager->getTask($task->getId());
	}
}
|