diff --git a/aider_gitea/__init__.py b/aider_gitea/__init__.py index c7c6852..6345df6 100644 --- a/aider_gitea/__init__.py +++ b/aider_gitea/__init__.py @@ -70,6 +70,7 @@ The tool uses environment variables for sensitive information: ``` """ +import dataclasses import logging import re import subprocess @@ -83,6 +84,27 @@ from ._version import __version__ # noqa: F401 logger = logging.getLogger(__name__) +@dataclasses.dataclass(frozen=True) +class RepositoryConfig: + gitea_url: str + owner: str + repo: str + base_branch: str + + def repo_url(self) -> str: + return f'{self.gitea_url}:{self.owner}/{self.repo}.git'.replace( + 'https://', + 'git@', + ) + + +@dataclasses.dataclass(frozen=True) +class IssueResolution: + success: bool + pull_request_url: str | None = None + pull_request_id: str | None = None + + def generate_branch_name(issue_number: str, issue_title: str) -> str: """Create a branch name by sanitizing the issue title. @@ -197,29 +219,29 @@ def get_commit_messages(cwd: Path, base_branch: str, current_branch: str) -> lis capture_output=True, text=True, ) - return reversed(result.stdout.strip().split('\n')) + return list(reversed(result.stdout.strip().split('\n'))) except subprocess.CalledProcessError: logger.exception(f'Failed to get commit messages on branch {current_branch}') - return '' + return [] def push_changes( + repository_config: RepositoryConfig, cwd: Path, branch_name: str, issue_number: str, issue_title: str, - base_branch: str, gitea_client, - owner: str, - repo: str, -) -> bool: +) -> IssueResolution: # Check if there are any commits on the branch before pushing - if not has_commits_on_branch(cwd, base_branch, branch_name): + if not has_commits_on_branch(cwd, repository_config.base_branch, branch_name): logger.info('No commits made on branch %s, skipping push', branch_name) - return False + return IssueResolution(False) # Get commit messages for PR description - commit_messages = get_commit_messages(cwd, base_branch, branch_name) + commit_messages = get_commit_messages( + cwd, repository_config.base_branch, branch_name, + ) description = f'This pull request resolves #{issue_number}\n\n' if commit_messages: @@ -232,16 +254,20 @@ def push_changes( run_cmd(cmd, cwd) # Then create the PR with the aider label - gitea_client.create_pull_request( - owner=owner, - repo=repo, + pr_response = gitea_client.create_pull_request( + owner=repository_config.owner, + repo=repository_config.repo, title=issue_title, body=description, head=branch_name, - base=base_branch, + base=repository_config.base_branch, labels=['aider'], ) - return True + + # Extract PR number and URL if available + return IssueResolution( + True, str(pr_response.get('number')), pr_response.get('html_url'), + ) def has_commits_on_branch(cwd: Path, base_branch: str, current_branch: str) -> bool: @@ -283,28 +309,25 @@ def run_cmd(cmd: list[str], cwd: Path | None = None, check=True) -> bool: result = subprocess.run(cmd, check=check, cwd=cwd) return result.returncode == 0 + SKIP_AIDER = False + def solve_issue_in_repository( - args, + repository_config: RepositoryConfig, tmpdirname: Path, branch_name: str, issue_title: str, issue_description: str, issue_number: str, - gitea_client=None, -) -> bool: - logger.info("### %s #####", issue_title) - - repo_url = f'{args.gitea_url}:{args.owner}/{args.repo}.git'.replace( - 'https://', - 'git@', - ) + gitea_client, +) -> IssueResolution: + logger.info('### %s #####', issue_title) # Setup repository - run_cmd(['git', 'clone', repo_url, tmpdirname]) + run_cmd(['git', 'clone', repository_config.repo_url(), tmpdirname]) run_cmd(['bash', '-c', AIDER_TEST], tmpdirname) - run_cmd(['git', 'checkout', args.base_branch], tmpdirname) + run_cmd(['git', 'checkout', repository_config.base_branch], tmpdirname) run_cmd(['git', 'checkout', '-b', branch_name], tmpdirname) # Run initial ruff pass before aider @@ -331,11 +354,11 @@ def solve_issue_in_repository( check=False, ) else: - logger.warning("Skipping aider command (for testing)") + logger.warning('Skipping aider command (for testing)') succeeded = True if not succeeded: logger.error('Aider invocation failed for issue #%s', issue_number) - return False + return IssueResolution(False) # Auto-fix standard code quality stuff after aider run_cmd(['bash', '-c', RUFF_FORMAT_AND_AUTO_FIX], tmpdirname, check=False) @@ -357,52 +380,52 @@ def solve_issue_in_repository( 'Aider did not make any changes beyond the initial ruff pass for issue #%s', issue_number, ) - return False + return IssueResolution(False) # Push changes return push_changes( + repository_config, tmpdirname, branch_name, issue_number, issue_title, - args.base_branch, gitea_client, - args.owner, - args.repo, ) -def handle_issues(args, client, seen_issues_db): +def solve_issues_in_repository( + repository_config: RepositoryConfig, client, seen_issues_db, +): """Process all open issues with the 'aider' label. Args: - args: Command line arguments. + repository_config: Command line arguments. client: The Gitea client instance. seen_issues_db: Database of previously processed issues. """ try: - issues = client.get_issues(args.owner, args.repo) + issues = client.get_issues(repository_config.owner, repository_config.repo) except Exception: logger.exception('Failed to retrieve issues') sys.exit(1) if not issues: - logger.info('No issues found for %s', args.repo) + logger.info('No issues found for %s', repository_config.repo) return for issue in issues: + issue_url = issue.get('web_url') issue_number = issue.get('number') issue_description = issue.get('body', '') title = issue.get('title', f'Issue {issue_number}') - issue_text = f'{title}\n{issue_description}' - if seen_issues_db.has_seen(issue_text): + if seen_issues_db.has_seen(issue_url): logger.info('Skipping already processed issue #%s: %s', issue_number, title) continue branch_name = generate_branch_name(issue_number, title) with tempfile.TemporaryDirectory() as tmpdirname: - solved = solve_issue_in_repository( - args, + issue_resolution = solve_issue_in_repository( + repository_config, Path(tmpdirname), branch_name, title, @@ -411,5 +434,15 @@ def handle_issues(args, client, seen_issues_db): client, ) - if solved: - seen_issues_db.mark_as_seen(issue_text) + if issue_resolution.success: + seen_issues_db.mark_as_seen(issue_url, str(issue_number)) + seen_issues_db.update_pr_info( + issue_url, + issue_resolution.pull_request_id, + issue_resolution.pull_request_url, + ) + logger.info( + 'Stored PR #%s information for issue #%s', + issue_resolution.pull_request_id, + issue_number, + ) diff --git a/aider_gitea/__main__.py b/aider_gitea/__main__.py index df028f5..d2e1185 100644 --- a/aider_gitea/__main__.py +++ b/aider_gitea/__main__.py @@ -7,23 +7,14 @@ It assumes that the default branch (default "main") exists and that you have a v import argparse import logging import time -from dataclasses import dataclass -from . import handle_issues, secrets +from . import RepositoryConfig, secrets, solve_issues_in_repository from .gitea_client import GiteaClient from .seen_issues_db import SeenIssuesDB logger = logging.getLogger(__name__) -@dataclass -class AiderArgs: - gitea_url: str - owner: str - repo: str - base_branch: str - - def parse_args(): parser = argparse.ArgumentParser( description='Download issues and create pull requests for a Gitea repository.', @@ -72,13 +63,13 @@ def main(): while True: logger.info('Checking for new issues...') for repo in repositories: - aider_args = AiderArgs( + repository_config = RepositoryConfig( gitea_url=args.gitea_url, owner=args.owner, repo=repo, base_branch=args.base_branch, ) - handle_issues(aider_args, client, seen_issues_db) + solve_issues_in_repository(repository_config, client, seen_issues_db) del repo if not args.daemon: break diff --git a/aider_gitea/seen_issues_db.py b/aider_gitea/seen_issues_db.py index 85a4ae7..0a34383 100644 --- a/aider_gitea/seen_issues_db.py +++ b/aider_gitea/seen_issues_db.py @@ -1,22 +1,22 @@ -"""Database module for tracking previously processed issues. +"""Database module for tracking previously processed issues and pull requests. This module provides functionality to track which issues have already been processed by the system to avoid duplicate processing. It uses a simple SQLite database to -store hashes of seen issues for efficient lookup. +store information about seen issues and their associated pull requests for efficient lookup. """ import sqlite3 -from hashlib import sha256 DEFAULT_DB_PATH = 'output/seen_issues.db' class SeenIssuesDB: - """Database handler for tracking processed issues. + """Database handler for tracking processed issues and pull requests. - This class manages a SQLite database that stores hashes of issues that have - already been processed. It provides methods to mark issues as seen and check - if an issue has been seen before, helping to prevent duplicate processing. + This class manages a SQLite database that stores information about issues that have + already been processed and their associated pull requests. It provides methods to mark + issues as seen, check if an issue has been seen before, and retrieve pull request + information for an issue. Attributes: conn: SQLite database connection @@ -34,56 +34,90 @@ class SeenIssuesDB: def _create_table(self): """Create the seen_issues table if it doesn't exist. - Creates a table with a single column for storing issue hashes. + Creates a table with columns for storing issue hashes and associated pull request information. """ with self.conn: self.conn.execute(""" CREATE TABLE IF NOT EXISTS seen_issues ( - issue_hash TEXT PRIMARY KEY + issue_url TEXT PRIMARY KEY, + issue_number TEXT, + pr_number TEXT, + pr_url TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) - def mark_as_seen(self, issue_text: str): + def mark_as_seen( + self, + issue_url: str, + issue_number: str | None = None, + pr_number: str | None = None, + pr_url: str | None = None, + ): """Mark an issue as seen in the database. - Computes a hash of the issue text and stores it in the database. + Computes a hash of the issue text and stores it in the database along with pull request information. If the issue has already been marked as seen, this operation has no effect. Args: - issue_text: The text content of the issue to mark as seen. + issue_url: The text content of the issue to mark as seen. + issue_number: The issue number. + pr_number: The pull request number associated with this issue. + pr_url: The URL of the pull request associated with this issue. """ - issue_hash = self._compute_hash(issue_text) with self.conn: self.conn.execute( - 'INSERT OR IGNORE INTO seen_issues (issue_hash) VALUES (?)', - (issue_hash,), + 'INSERT OR IGNORE INTO seen_issues (issue_url, issue_number, pr_number, pr_url) VALUES (?, ?, ?, ?)', + (issue_url, issue_number, pr_number, pr_url), ) - def has_seen(self, issue_text: str) -> bool: + def has_seen(self, issue_url: str) -> bool: """Check if an issue has been seen before. Computes a hash of the issue text and checks if it exists in the database. Args: - issue_text: The text content of the issue to check. + issue_url: The text content of the issue to check. Returns: True if the issue has been seen before, False otherwise. """ - issue_hash = self._compute_hash(issue_text) cursor = self.conn.execute( - 'SELECT 1 FROM seen_issues WHERE issue_hash = ?', - (issue_hash,), + 'SELECT 1 FROM seen_issues WHERE issue_url = ?', + (issue_url,), ) return cursor.fetchone() is not None - def _compute_hash(self, text: str) -> str: - """Compute a SHA-256 hash of the given text. + def get_pr_info(self, issue_url: str) -> tuple[str, str] | None: + """Get pull request information for an issue. Args: - text: The text to hash. + issue_url: The text content of the issue to check. Returns: - A hexadecimal string representation of the hash. + A tuple containing (pr_number, pr_url) if found, None otherwise. """ - return sha256(text.encode('utf-8')).hexdigest() + cursor = self.conn.execute( + 'SELECT pr_number, pr_url FROM seen_issues WHERE issue_url = ?', + (issue_url,), + ) + result = cursor.fetchone() + return result if result else None + + def update_pr_info(self, issue_url: str, pr_number: str, pr_url: str) -> bool: + """Update pull request information for an existing issue. + + Args: + issue_url: The text content of the issue to update. + pr_number: The pull request number. + pr_url: The URL of the pull request. + + Returns: + True if the update was successful, False if the issue wasn't found. + """ + with self.conn: + cursor = self.conn.execute( + 'UPDATE seen_issues SET pr_number = ?, pr_url = ? WHERE issue_url = ?', + (pr_number, pr_url, issue_url), + ) + return cursor.rowcount > 0 diff --git a/test/test_seen_issues_db_pr_info.py b/test/test_seen_issues_db_pr_info.py new file mode 100644 index 0000000..7083115 --- /dev/null +++ b/test/test_seen_issues_db_pr_info.py @@ -0,0 +1,77 @@ +import os +import tempfile + +from aider_gitea.seen_issues_db import SeenIssuesDB + + +class TestSeenIssuesDBPRInfo: + def setup_method(self): + # Create a temporary database file + self.db_fd, self.db_path = tempfile.mkstemp() + self.db = SeenIssuesDB(self.db_path) + + # Test data + self.issue_text = 'Test issue title\nTest issue description' + self.issue_number = '123' + self.pr_number = '456' + self.pr_url = 'https://gitea.example.com/owner/repo/pulls/456' + + def teardown_method(self): + # Close and remove the temporary database + self.db.conn.close() + os.close(self.db_fd) + os.unlink(self.db_path) + + def test_mark_as_seen_with_pr_info(self): + # Mark an issue as seen with PR info + self.db.mark_as_seen( + self.issue_text, + issue_number=self.issue_number, + pr_number=self.pr_number, + pr_url=self.pr_url, + ) + + # Verify the issue is marked as seen + assert self.db.has_seen(self.issue_text) + + # Verify PR info is stored correctly + pr_info = self.db.get_pr_info(self.issue_text) + assert pr_info is not None + assert pr_info[0] == self.pr_number + assert pr_info[1] == self.pr_url + + def test_update_pr_info(self): + # First mark the issue as seen without PR info + self.db.mark_as_seen(self.issue_text, issue_number=self.issue_number) + + # Verify no PR info is available + assert self.db.get_pr_info(self.issue_text) == (None, None) + + # Update with PR info + updated = self.db.update_pr_info(self.issue_text, self.pr_number, self.pr_url) + + # Verify update was successful + assert updated + + # Verify PR info is now available + pr_info = self.db.get_pr_info(self.issue_text) + assert pr_info[0] == self.pr_number + assert pr_info[1] == self.pr_url + + def test_update_nonexistent_issue(self): + # Try to update PR info for an issue that doesn't exist + updated = self.db.update_pr_info( + 'Nonexistent issue', + self.pr_number, + self.pr_url, + ) + + # Verify update failed + assert not updated + + def test_get_pr_info_nonexistent(self): + # Try to get PR info for an issue that doesn't exist + pr_info = self.db.get_pr_info('Nonexistent issue') + + # Verify no PR info is available + assert pr_info is None diff --git a/test/test_solve_issue_in_repository.py b/test/test_solve_issue_in_repository.py index 0377094..4091d28 100644 --- a/test/test_solve_issue_in_repository.py +++ b/test/test_solve_issue_in_repository.py @@ -1,17 +1,18 @@ from pathlib import Path from unittest.mock import MagicMock, patch -from aider_gitea import solve_issue_in_repository +from aider_gitea import IssueResolution, RepositoryConfig, solve_issue_in_repository + +REPOSITORY_CONFIG = RepositoryConfig( + gitea_url='https://gitea.example.com', + owner='test-owner', + repo='test-repo', + base_branch='main', +) class TestSolveIssueInRepository: def setup_method(self): - self.args = MagicMock() - self.args.gitea_url = 'https://gitea.example.com' - self.args.owner = 'test-owner' - self.args.repo = 'test-repo' - self.args.base_branch = 'main' - self.gitea_client = MagicMock() self.tmpdirname = Path('/tmp/test-repo') self.branch_name = 'issue-123-test-branch' @@ -24,23 +25,32 @@ class TestSolveIssueInRepository: @patch('aider_gitea.push_changes') @patch('subprocess.run') def test_solve_issue_with_aider_changes( - self, mock_subprocess_run, mock_push_changes, mock_run_cmd, mock_llm_api_key, + self, + mock_subprocess_run, + mock_push_changes, + mock_run_cmd, + mock_llm_api_key, ): # Setup mocks mock_run_cmd.return_value = True - mock_push_changes.return_value = True + mock_push_changes.return_value = IssueResolution( + True, + '456', + 'https://gitea.example.com/test-owner/test-repo/pulls/456', + ) # Mock subprocess.run to return different commit hashes and file changes mock_subprocess_run.side_effect = [ MagicMock(stdout='abc123\n', returncode=0), # First git rev-parse MagicMock( - stdout='file1.py\nfile2.py\n', returncode=0, + stdout='file1.py\nfile2.py\n', + returncode=0, ), # git diff with changes ] # Call the function result = solve_issue_in_repository( - self.args, + REPOSITORY_CONFIG, self.tmpdirname, self.branch_name, self.issue_title, @@ -50,7 +60,7 @@ class TestSolveIssueInRepository: ) # Verify results - assert result is True + assert result.success is True assert mock_run_cmd.call_count >= 8 # Verify all expected commands were run mock_push_changes.assert_called_once() @@ -59,10 +69,15 @@ class TestSolveIssueInRepository: @patch('aider_gitea.push_changes') @patch('subprocess.run') def test_solve_issue_without_aider_changes( - self, mock_subprocess_run, mock_push_changes, mock_run_cmd, mock_llm_api_key, + self, + mock_subprocess_run, + mock_push_changes, + mock_run_cmd, + mock_llm_api_key, ): # Setup mocks mock_run_cmd.return_value = True + mock_push_changes.return_value = IssueResolution(False, None, None) # Mock subprocess.run to return same commit hash and no file changes mock_subprocess_run.side_effect = [ @@ -72,7 +87,7 @@ class TestSolveIssueInRepository: # Call the function result = solve_issue_in_repository( - self.args, + REPOSITORY_CONFIG, self.tmpdirname, self.branch_name, self.issue_title, @@ -82,5 +97,5 @@ class TestSolveIssueInRepository: ) # Verify results - assert result is False + assert result.success is False assert mock_push_changes.call_count == 0 # push_changes should not be called