316 lines
11 KiB
Python
316 lines
11 KiB
Python
import logging
from collections.abc import Iterator
from typing import Any, Optional

import requests

from .models import (
    GiteaIssue,
    GiteaLabel,
    GiteaUser,
)

logger = logging.getLogger(__name__)

# Module-level constants

# Path suffix appended to the instance URL to form the API root.
API_VERSION_PATH = '/api/v1'
# Content-Type header value installed on the client's HTTP session.
DEFAULT_CONTENT_TYPE = 'application/json'
# Issues are only returned by GiteaClient.get_issues when they carry this label.
AIDER_LABEL_NAME = 'aider'
# Workflow-run conclusion that GiteaClient.get_failed_pipelines does not report.
SUCCESS_CONCLUSION = 'success'
# Status code treated as "pull request already exists" when creating a pull request.
CONFLICT_STATUS_CODE = 409
# Status code treated as "branch already exists" when creating a branch.
UNPROCESSABLE_ENTITY_STATUS_CODE = 422


class GiteaClient:
    """Client for interacting with the Gitea API.

    This class provides methods to interact with a Gitea instance's API,
    including retrieving branch commits, creating branches and pull requests,
    listing issues and repositories, and inspecting pipeline runs.

    Read more about the Gitea API here: https://gitea.com/api/swagger

    Follows the standardized client format:
    1. Constructor takes a requests.Session object
    2. All secrets are provided via keyword arguments
    3. ROOT_URL constant field for constructing URLs

    Attributes:
        session (requests.Session): HTTP session for making API requests.
        ROOT_URL (str): The base URL for the Gitea API endpoints.
    """

    def __init__(
        self,
        session: requests.Session,
        *,
        gitea_url: str,
        token: str = '',
    ) -> None:
        """Initialize a new Gitea API client.

        The given session is mutated: its 'Content-Type' header is set and,
        when a token is supplied, its 'Authorization' header as well.

        Args:
            session: HTTP session object to use for requests.
            gitea_url: Base URL for the Gitea instance (without '/api/v1').
                Trailing slashes are ignored.
            token: Authentication token for the Gitea API. If empty, requests will be unauthenticated.

        Raises:
            AssertionError: If gitea_url ends with '/api/v1'.
        """
        # Drop trailing slashes so ROOT_URL never contains '//api/v1'.
        base_url = gitea_url.rstrip('/')
        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # the exception type is kept for callers that already catch it.
        if base_url.endswith(API_VERSION_PATH):
            raise AssertionError(
                f'gitea_url must not end with {API_VERSION_PATH!r}: {gitea_url!r}'
            )

        self.session = session
        self.ROOT_URL = base_url + API_VERSION_PATH
        self.session.headers['Content-Type'] = DEFAULT_CONTENT_TYPE
        if token:
            self.session.headers['Authorization'] = f'token {token}'

    def get_default_branch_sha(self, owner: str, repo: str, branch_name: str) -> str:
        """Retrieve the commit SHA of the specified branch.

        Args:
            owner: Owner of the repository.
            repo: Name of the repository.
            branch_name: Name of the branch.

        Returns:
            The commit SHA of the specified branch.

        Raises:
            requests.HTTPError: If the API request fails.
            KeyError: If the response carries no commit hash.
        """
        api_url = f'{self.ROOT_URL}/repos/{owner}/{repo}/branches/{branch_name}'
        response = self.session.get(api_url)
        response.raise_for_status()
        commit_data = response.json()['commit']
        # GitHub-style payloads name the hash 'sha'; Gitea's branch payload is
        # believed to expose it as 'id' (see the Gitea swagger). Accept either.
        return commit_data.get('sha') or commit_data['id']

    def create_branch(
        self,
        owner: str,
        repo: str,
        new_branch_name: str,
        commit_sha: str,
    ) -> bool:
        """Create a new branch from the provided SHA.

        Args:
            owner: Owner of the repository.
            repo: Name of the repository.
            new_branch_name: Name of the new branch to create.
            commit_sha: Commit SHA to use as the starting point for the new branch.

        Returns:
            True if the branch was created successfully, False if the branch already exists.

        Raises:
            requests.HTTPError: If the API request fails for reasons other than branch already existing.
        """
        # NOTE(review): this posts a GitHub-style ref payload to /git/refs. Gitea's
        # swagger documents branch creation under POST /repos/{owner}/{repo}/branches;
        # confirm the target server accepts this route before relying on it.
        api_url = f'{self.ROOT_URL}/repos/{owner}/{repo}/git/refs'
        request_payload = {'ref': f'refs/heads/{new_branch_name}', 'sha': commit_sha}
        response = self.session.post(api_url, json=request_payload)

        # A 422 response is interpreted as "the branch is already there".
        if response.status_code == UNPROCESSABLE_ENTITY_STATUS_CODE:
            logger.warning('Branch %s already exists.', new_branch_name)
            return False

        response.raise_for_status()
        return True

    @staticmethod
    def _parse_user(user_data: dict[str, Any]) -> GiteaUser:
        """Build a GiteaUser from one user payload; optional fields default to ''."""
        return GiteaUser(
            login=user_data['login'],
            id=user_data['id'],
            full_name=user_data.get('full_name', ''),
            email=user_data.get('email', ''),
            avatar_url=user_data.get('avatar_url', ''),
        )

    @staticmethod
    def _parse_label(label_data: dict[str, Any]) -> GiteaLabel:
        """Build a GiteaLabel from one label payload; description defaults to ''."""
        return GiteaLabel(
            id=label_data['id'],
            name=label_data['name'],
            color=label_data['color'],
            description=label_data.get('description', ''),
        )

    def get_issues(self, owner: str, repo: str) -> list[GiteaIssue]:
        """Download issues from the specified repository and filter those with the 'aider' label.

        Only the single response returned by the server is read; no pagination
        or state parameters are sent.

        Args:
            owner: Owner of the repository.
            repo: Name of the repository.

        Returns:
            A list of GiteaIssue objects, filtered to only include issues with the 'aider' label.

        Raises:
            requests.HTTPError: If the API request fails.
        """
        api_url = f'{self.ROOT_URL}/repos/{owner}/{repo}/issues'
        response = self.session.get(api_url)
        response.raise_for_status()

        gitea_issues = []
        for issue_data in response.json():
            # `.get(key, [])` still yields None when the key is present with a
            # JSON null, so normalise both collections to a list before iterating.
            labels_data = issue_data.get('labels') or []
            assignees_data = issue_data.get('assignees') or []

            # Keep only issues marked with the "aider" label.
            if not any(
                label_data.get('name') == AIDER_LABEL_NAME for label_data in labels_data
            ):
                continue

            gitea_issues.append(
                GiteaIssue(
                    id=issue_data['id'],
                    number=issue_data['number'],
                    title=issue_data['title'],
                    body=issue_data.get('body', ''),
                    state=issue_data['state'],
                    labels=[self._parse_label(label_data) for label_data in labels_data],
                    user=self._parse_user(issue_data['user']),
                    assignees=[
                        self._parse_user(assignee_data) for assignee_data in assignees_data
                    ],
                    html_url=issue_data['html_url'],
                    created_at=issue_data['created_at'],
                    updated_at=issue_data['updated_at'],
                )
            )

        return gitea_issues

    def iter_user_repositories(
        self,
        owner_name: str,
        only_those_with_issues: bool = False,
    ) -> Iterator[str]:
        """Iterate over the names of repositories owned by a given user.

        The repositories are read from the '/user/repos' endpoint and then
        filtered by owner login (case-insensitively). Only the single response
        returned by the server is read; no pagination parameters are sent.

        Args:
            owner_name: The owner of the repositories.
            only_those_with_issues: If True, only return repositories with issues enabled.

        Returns:
            An iterator of repository names.

        Raises:
            requests.HTTPError: If the API request fails.
        """
        api_url = f'{self.ROOT_URL}/user/repos'
        response = self.session.get(api_url)
        response.raise_for_status()

        wanted_owner = owner_name.lower()
        for repository_data in response.json():
            if only_those_with_issues and not repository_data['has_issues']:
                continue
            if repository_data['owner']['login'].lower() != wanted_owner:
                continue
            yield repository_data['name']

    def _resolve_label_ids(self, owner: str, repo: str, label_names: list[str]) -> list[int]:
        """Translate label names into the repository's label IDs, skipping unknown names."""
        api_url = f'{self.ROOT_URL}/repos/{owner}/{repo}/labels'
        response = self.session.get(api_url)
        response.raise_for_status()
        ids_by_name = {
            label_data['name']: label_data['id'] for label_data in response.json()
        }

        label_ids = []
        for label_name in label_names:
            if label_name in ids_by_name:
                label_ids.append(ids_by_name[label_name])
            else:
                logger.warning(
                    'Label %s not found in %s/%s; it will not be applied.',
                    label_name,
                    owner,
                    repo,
                )
        return label_ids

    def create_pull_request(
        self,
        owner: str,
        repo: str,
        title: str,
        body: str,
        head: str,
        base: str,
        labels: Optional[list[str]] = None,
    ) -> dict:
        """Create a pull request and optionally apply labels.

        If the server answers 409 (a pull request for this head/base already
        exists), the matching open pull request is returned instead.

        Args:
            owner (str): Owner of the repository.
            repo (str): Name of the repository.
            title (str): Title of the pull request.
            body (str): Description/body of the pull request.
            head (str): The name of the branch where changes are implemented.
            base (str): The name of the branch you want the changes pulled into.
            labels (list[str], optional): List of label names to apply to the pull request.
                Names are looked up among the repository's labels; names that
                are not found are logged and skipped.

        Returns:
            dict: The created (or already existing) pull request data.

        Raises:
            requests.HTTPError: If the API request fails.
        """
        api_url = f'{self.ROOT_URL}/repos/{owner}/{repo}/pulls'
        request_payload: dict[str, Any] = {
            'title': title,
            'body': body,
            'head': head,
            'base': base,
        }

        if labels:
            # The pull-request creation payload is understood to take label IDs
            # rather than names (see the Gitea swagger), so resolve them first.
            label_ids = self._resolve_label_ids(owner, repo, labels)
            if label_ids:
                request_payload['labels'] = label_ids

        response = self.session.post(api_url, json=request_payload)

        # If a pull request for this head/base already exists, return it instead of crashing
        if response.status_code == CONFLICT_STATUS_CODE:
            logger.warning(
                'Pull request already exists for head %s and base %s',
                head,
                base,
            )
            for existing_pr in self.get_pull_requests(owner, repo):
                if (
                    existing_pr.get('head', {}).get('ref') == head
                    and existing_pr.get('base', {}).get('ref') == base
                ):
                    return existing_pr
            # No match found: fall through so raise_for_status() reports the 409.

        response.raise_for_status()
        return response.json()

    def get_failed_pipelines(
        self,
        owner: str,
        repo: str,
        pull_request_number: str,
    ) -> list[int]:
        """Fetch pipeline runs for a PR and return IDs of failed runs.

        A run counts as failed when its 'conclusion' is anything other than
        'success', including a missing conclusion.

        Args:
            owner: Owner of the repository.
            repo: Name of the repository.
            pull_request_number: Number of the pull request; converted with int().

        Returns:
            IDs of the non-successful workflow runs associated with the pull request.

        Raises:
            ValueError: If pull_request_number cannot be converted to an integer.
            requests.HTTPError: If the API request fails.
        """
        # Convert once instead of once per inspected pull-request entry.
        target_number = int(pull_request_number)

        api_url = f'{self.ROOT_URL}/repos/{owner}/{repo}/actions/runs'
        response = self.session.get(api_url)
        response.raise_for_status()

        # `or []` guards against keys that are present with a JSON null value.
        workflow_runs = response.json().get('workflow_runs') or []
        return [
            workflow_run.get('id')
            for workflow_run in workflow_runs
            if workflow_run.get('conclusion') != SUCCESS_CONCLUSION
            and any(
                pull_request.get('number') == target_number
                for pull_request in workflow_run.get('pull_requests') or []
            )
        ]

    def get_pipeline_log(self, owner: str, repo: str, workflow_run_id: int) -> str:
        """Download the logs for a pipeline run.

        Args:
            owner: Owner of the repository.
            repo: Name of the repository.
            workflow_run_id: ID of the workflow run whose logs are requested.

        Returns:
            The response body as text.

        Raises:
            requests.HTTPError: If the API request fails.
        """
        runs_url = f'{self.ROOT_URL}/repos/{owner}/{repo}/actions/runs'
        response = self.session.get(f'{runs_url}/{workflow_run_id}/logs')
        response.raise_for_status()
        return response.text

    def get_pull_requests(
        self,
        owner: str,
        repo: str,
        pull_request_state: str = 'open',
    ) -> list[dict[str, Any]]:
        """Fetch pull requests for a repository.

        Args:
            owner: Owner of the repository.
            repo: Name of the repository.
            pull_request_state: Value sent as the 'state' query parameter.

        Returns:
            The decoded JSON response: a list of pull request payloads.

        Raises:
            requests.HTTPError: If the API request fails.
        """
        api_url = f'{self.ROOT_URL}/repos/{owner}/{repo}/pulls'
        # Let requests build (and escape) the query string.
        response = self.session.get(api_url, params={'state': pull_request_state})
        response.raise_for_status()
        return response.json()