diff --git a/aider_gitea/__init__.py b/aider_gitea/__init__.py index 7344de0..7795aac 100644 --- a/aider_gitea/__init__.py +++ b/aider_gitea/__init__.py @@ -205,6 +205,9 @@ def push_changes( issue_number: str, issue_title: str, base_branch: str, + gitea_client=None, + owner=None, + repo=None, ) -> bool: # Check if there are any commits on the branch before pushing if not has_commits_on_branch(cwd, base_branch, branch_name): @@ -220,22 +223,40 @@ def push_changes( for message in commit_messages.split('\n'): description += f'- {message}\n' - description = description.replace('\n', '
') - - cmd = [ - 'git', - 'push', - 'origin', - f'HEAD:refs/for/{base_branch}', - '-o', - f'topic={branch_name}', - '-o', - f'title={issue_title}', - '-o', - f'description="{description}"', - ] - run_cmd(cmd, cwd) - return True + # If we have a Gitea client, create the PR with the aider label + if gitea_client and owner and repo: + # First push the branch without creating a PR + cmd = ['git', 'push', 'origin', branch_name] + run_cmd(cmd, cwd) + + # Then create the PR with the aider label + gitea_client.create_pull_request( + owner=owner, + repo=repo, + title=issue_title, + body=description, + head=branch_name, + base=base_branch, + labels=['aider'] + ) + return True + else: + # Fall back to the original method if no Gitea client is provided + description_formatted = description.replace('\n', '
') + cmd = [ + 'git', + 'push', + 'origin', + f'HEAD:refs/for/{base_branch}', + '-o', + f'topic={branch_name}', + '-o', + f'title={issue_title}', + '-o', + f'description="{description_formatted}"', + ] + run_cmd(cmd, cwd) + return True def has_commits_on_branch(cwd: Path, base_branch: str, current_branch: str) -> bool: @@ -285,6 +306,7 @@ def solve_issue_in_repository( issue_title: str, issue_description: str, issue_number: str, + gitea_client=None, ) -> bool: repo_url = f'{args.gitea_url}:{args.owner}/{args.repo}.git'.replace( 'https://', @@ -320,6 +342,9 @@ def solve_issue_in_repository( issue_number, issue_title, args.base_branch, + gitea_client, + args.owner, + args.repo, ) @@ -359,6 +384,7 @@ def handle_issues(args, client, seen_issues_db): title, issue_description, issue_number, + client, ) if solved: diff --git a/aider_gitea/gitea_client.py b/aider_gitea/gitea_client.py index 0ebeb9f..c712a4f 100644 --- a/aider_gitea/gitea_client.py +++ b/aider_gitea/gitea_client.py @@ -129,3 +129,76 @@ class GiteaClient: if repo['owner']['login'].lower() != owner.lower(): continue yield repo['name'] + + def create_pull_request( + self, + owner: str, + repo: str, + title: str, + body: str, + head: str, + base: str, + labels: list[str] = None + ) -> dict: + """Create a pull request and optionally apply labels. + + Args: + owner (str): Owner of the repository. + repo (str): Name of the repository. + title (str): Title of the pull request. + body (str): Description/body of the pull request. + head (str): The name of the branch where changes are implemented. + base (str): The name of the branch you want the changes pulled into. + labels (list[str], optional): List of label names to apply to the pull request. + + Returns: + dict: The created pull request data. + + Raises: + requests.HTTPError: If the API request fails. + """ + url = f'{self.gitea_url}/repos/{owner}/{repo}/pulls' + json_data = { + 'title': title, + 'body': body, + 'head': head, + 'base': base + } + + response = self.session.post(url, json=json_data) + response.raise_for_status() + pull_request = response.json() + + # Apply labels if provided + if labels and pull_request.get('number'): + self.add_labels_to_pull_request(owner, repo, pull_request['number'], labels) + + return pull_request + + def add_labels_to_pull_request( + self, + owner: str, + repo: str, + pull_number: int, + labels: list[str] + ) -> bool: + """Add labels to an existing pull request. + + Args: + owner (str): Owner of the repository. + repo (str): Name of the repository. + pull_number (int): The pull request number. + labels (list[str]): List of label names to apply. + + Returns: + bool: True if labels were successfully applied. + + Raises: + requests.HTTPError: If the API request fails. + """ + url = f'{self.gitea_url}/repos/{owner}/{repo}/issues/{pull_number}/labels' + json_data = {'labels': labels} + + response = self.session.post(url, json=json_data) + response.raise_for_status() + return True diff --git a/test/test_gitea_client_pr_labels.py b/test/test_gitea_client_pr_labels.py new file mode 100644 index 0000000..31dfc20 --- /dev/null +++ b/test/test_gitea_client_pr_labels.py @@ -0,0 +1,76 @@ +import pytest +from unittest.mock import MagicMock, patch +from aider_gitea.gitea_client import GiteaClient + + +class TestGiteaClientPRLabels: + def setup_method(self): + self.client = GiteaClient("https://gitea.example.com", "fake_token") + + @patch('requests.Session.post') + def test_create_pull_request_with_labels(self, mock_post): + # Mock the PR creation response + pr_response = MagicMock() + pr_response.status_code = 201 + pr_response.json.return_value = { + 'number': 123, + 'title': 'Test PR', + 'html_url': 'https://gitea.example.com/owner/repo/pulls/123' + } + + # Mock the label addition response + label_response = MagicMock() + label_response.status_code = 200 + + # Set up the mock to return different responses for different calls + mock_post.side_effect = [pr_response, label_response] + + # Call the method with labels + result = self.client.create_pull_request( + owner="owner", + repo="repo", + title="Test PR", + body="Test body", + head="feature-branch", + base="main", + labels=["aider"] + ) + + # Verify PR creation call + assert mock_post.call_count == 2 + pr_call_args = mock_post.call_args_list[0] + assert pr_call_args[0][0] == 'https://gitea.example.com/api/v1/repos/owner/repo/pulls' + assert pr_call_args[1]['json']['title'] == 'Test PR' + + # Verify label addition call + label_call_args = mock_post.call_args_list[1] + assert label_call_args[0][0] == 'https://gitea.example.com/api/v1/repos/owner/repo/issues/123/labels' + assert label_call_args[1]['json']['labels'] == ['aider'] + + # Verify the result + assert result['number'] == 123 + assert result['title'] == 'Test PR' + + @patch('requests.Session.post') + def test_add_labels_to_pull_request(self, mock_post): + # Mock the response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + # Call the method + result = self.client.add_labels_to_pull_request( + owner="owner", + repo="repo", + pull_number=123, + labels=["aider", "bug"] + ) + + # Verify the call + mock_post.assert_called_once() + call_args = mock_post.call_args + assert call_args[0][0] == 'https://gitea.example.com/api/v1/repos/owner/repo/issues/123/labels' + assert call_args[1]['json']['labels'] == ['aider', 'bug'] + + # Verify the result + assert result is True