diff --git a/aider_gitea/__init__.py b/aider_gitea/__init__.py index 2ba421e..ef6d4c9 100644 --- a/aider_gitea/__init__.py +++ b/aider_gitea/__init__.py @@ -212,11 +212,13 @@ def push_changes( gitea_client, owner: str, repo: str, -) -> bool: + seen_issues_db=None, + issue_text: str = None, +) -> tuple[bool, str, str]: # Check if there are any commits on the branch before pushing if not has_commits_on_branch(cwd, base_branch, branch_name): logger.info('No commits made on branch %s, skipping push', branch_name) - return False + return False, None, None # Get commit messages for PR description commit_messages = get_commit_messages(cwd, base_branch, branch_name) @@ -232,7 +234,7 @@ def push_changes( run_cmd(cmd, cwd) # Then create the PR with the aider label - gitea_client.create_pull_request( + pr_response = gitea_client.create_pull_request( owner=owner, repo=repo, title=issue_title, @@ -241,7 +243,20 @@ def push_changes( base=base_branch, labels=['aider'], ) - return True + + # Extract PR number and URL if available + pr_number = None + pr_url = None + if pr_response and isinstance(pr_response, dict): + pr_number = str(pr_response.get('number')) + pr_url = pr_response.get('html_url') + + # Store PR information in the database if available + if seen_issues_db and issue_text and pr_number and pr_url: + seen_issues_db.update_pr_info(issue_text, pr_number, pr_url) + logger.info('Stored PR #%s information for issue #%s', pr_number, issue_number) + + return True, pr_number, pr_url def has_commits_on_branch(cwd: Path, base_branch: str, current_branch: str) -> bool: @@ -295,6 +310,7 @@ def solve_issue_in_repository( issue_description: str, issue_number: str, gitea_client=None, + seen_issues_db=None, ) -> bool: logger.info('### %s #####', issue_title) @@ -361,8 +377,11 @@ def solve_issue_in_repository( ) return False + # Create issue_text for database tracking + issue_text = f'{issue_title}\n{issue_description}' + # Push changes - return push_changes( + success, pr_number, pr_url = push_changes( tmpdirname, branch_name, issue_number, @@ -371,7 +390,11 @@ def solve_issue_in_repository( gitea_client, args.owner, args.repo, + seen_issues_db, + issue_text, ) + + return success def handle_issues(args, client, seen_issues_db): @@ -411,7 +434,8 @@ def handle_issues(args, client, seen_issues_db): issue_description, issue_number, client, + seen_issues_db, ) if solved: - seen_issues_db.mark_as_seen(issue_text) + seen_issues_db.mark_as_seen(issue_text, str(issue_number)) diff --git a/aider_gitea/seen_issues_db.py b/aider_gitea/seen_issues_db.py index 85a4ae7..84825c7 100644 --- a/aider_gitea/seen_issues_db.py +++ b/aider_gitea/seen_issues_db.py @@ -1,22 +1,24 @@ -"""Database module for tracking previously processed issues. +"""Database module for tracking previously processed issues and pull requests. This module provides functionality to track which issues have already been processed by the system to avoid duplicate processing. It uses a simple SQLite database to -store hashes of seen issues for efficient lookup. +store information about seen issues and their associated pull requests for efficient lookup. """ import sqlite3 from hashlib import sha256 +from typing import Optional, Tuple DEFAULT_DB_PATH = 'output/seen_issues.db' class SeenIssuesDB: - """Database handler for tracking processed issues. + """Database handler for tracking processed issues and pull requests. - This class manages a SQLite database that stores hashes of issues that have - already been processed. It provides methods to mark issues as seen and check - if an issue has been seen before, helping to prevent duplicate processing. + This class manages a SQLite database that stores information about issues that have + already been processed and their associated pull requests. It provides methods to mark + issues as seen, check if an issue has been seen before, and retrieve pull request + information for an issue. Attributes: conn: SQLite database connection @@ -34,29 +36,36 @@ class SeenIssuesDB: def _create_table(self): """Create the seen_issues table if it doesn't exist. - Creates a table with a single column for storing issue hashes. + Creates a table with columns for storing issue hashes and associated pull request information. """ with self.conn: self.conn.execute(""" CREATE TABLE IF NOT EXISTS seen_issues ( - issue_hash TEXT PRIMARY KEY + issue_hash TEXT PRIMARY KEY, + issue_number TEXT, + pr_number TEXT, + pr_url TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) - def mark_as_seen(self, issue_text: str): + def mark_as_seen(self, issue_text: str, issue_number: str = None, pr_number: str = None, pr_url: str = None): """Mark an issue as seen in the database. - Computes a hash of the issue text and stores it in the database. + Computes a hash of the issue text and stores it in the database along with pull request information. If the issue has already been marked as seen, this operation has no effect. Args: issue_text: The text content of the issue to mark as seen. + issue_number: The issue number. + pr_number: The pull request number associated with this issue. + pr_url: The URL of the pull request associated with this issue. """ issue_hash = self._compute_hash(issue_text) with self.conn: self.conn.execute( - 'INSERT OR IGNORE INTO seen_issues (issue_hash) VALUES (?)', - (issue_hash,), + 'INSERT OR IGNORE INTO seen_issues (issue_hash, issue_number, pr_number, pr_url) VALUES (?, ?, ?, ?)', + (issue_hash, issue_number, pr_number, pr_url), ) def has_seen(self, issue_text: str) -> bool: @@ -77,6 +86,42 @@ class SeenIssuesDB: ) return cursor.fetchone() is not None + def get_pr_info(self, issue_text: str) -> Optional[Tuple[str, str]]: + """Get pull request information for an issue. + + Args: + issue_text: The text content of the issue to check. + + Returns: + A tuple containing (pr_number, pr_url) if found, None otherwise. + """ + issue_hash = self._compute_hash(issue_text) + cursor = self.conn.execute( + 'SELECT pr_number, pr_url FROM seen_issues WHERE issue_hash = ?', + (issue_hash,), + ) + result = cursor.fetchone() + return result if result else None + + def update_pr_info(self, issue_text: str, pr_number: str, pr_url: str) -> bool: + """Update pull request information for an existing issue. + + Args: + issue_text: The text content of the issue to update. + pr_number: The pull request number. + pr_url: The URL of the pull request. + + Returns: + True if the update was successful, False if the issue wasn't found. + """ + issue_hash = self._compute_hash(issue_text) + with self.conn: + cursor = self.conn.execute( + 'UPDATE seen_issues SET pr_number = ?, pr_url = ? WHERE issue_hash = ?', + (pr_number, pr_url, issue_hash), + ) + return cursor.rowcount > 0 + def _compute_hash(self, text: str) -> str: """Compute a SHA-256 hash of the given text. diff --git a/test/test_seen_issues_db_pr_info.py b/test/test_seen_issues_db_pr_info.py new file mode 100644 index 0000000..a784212 --- /dev/null +++ b/test/test_seen_issues_db_pr_info.py @@ -0,0 +1,84 @@ +import os +import tempfile +from pathlib import Path + +import pytest + +from aider_gitea.seen_issues_db import SeenIssuesDB + + +class TestSeenIssuesDBPRInfo: + def setup_method(self): + # Create a temporary database file + self.db_fd, self.db_path = tempfile.mkstemp() + self.db = SeenIssuesDB(self.db_path) + + # Test data + self.issue_text = "Test issue title\nTest issue description" + self.issue_number = "123" + self.pr_number = "456" + self.pr_url = "https://gitea.example.com/owner/repo/pulls/456" + + def teardown_method(self): + # Close and remove the temporary database + self.db.conn.close() + os.close(self.db_fd) + os.unlink(self.db_path) + + def test_mark_as_seen_with_pr_info(self): + # Mark an issue as seen with PR info + self.db.mark_as_seen( + self.issue_text, + issue_number=self.issue_number, + pr_number=self.pr_number, + pr_url=self.pr_url + ) + + # Verify the issue is marked as seen + assert self.db.has_seen(self.issue_text) + + # Verify PR info is stored correctly + pr_info = self.db.get_pr_info(self.issue_text) + assert pr_info is not None + assert pr_info[0] == self.pr_number + assert pr_info[1] == self.pr_url + + def test_update_pr_info(self): + # First mark the issue as seen without PR info + self.db.mark_as_seen(self.issue_text, issue_number=self.issue_number) + + # Verify no PR info is available + assert self.db.get_pr_info(self.issue_text) == (None, None) + + # Update with PR info + updated = self.db.update_pr_info( + self.issue_text, + self.pr_number, + self.pr_url + ) + + # Verify update was successful + assert updated + + # Verify PR info is now available + pr_info = self.db.get_pr_info(self.issue_text) + assert pr_info[0] == self.pr_number + assert pr_info[1] == self.pr_url + + def test_update_nonexistent_issue(self): + # Try to update PR info for an issue that doesn't exist + updated = self.db.update_pr_info( + "Nonexistent issue", + self.pr_number, + self.pr_url + ) + + # Verify update failed + assert not updated + + def test_get_pr_info_nonexistent(self): + # Try to get PR info for an issue that doesn't exist + pr_info = self.db.get_pr_info("Nonexistent issue") + + # Verify no PR info is available + assert pr_info is None