Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AI-generated pull request feature #217

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import aicommits from './commands/aicommits.js';
import prepareCommitMessageHook from './commands/prepare-commit-msg-hook.js';
import configCommand from './commands/config.js';
import hookCommand, { isCalledFromGitHook } from './commands/hook.js';
import prCommand from './commands/aipr.js';

const rawArgv = process.argv.slice(2);

Expand Down Expand Up @@ -45,6 +46,7 @@ cli(
commands: [
configCommand,
hookCommand,
prCommand,
],

help: {
Expand Down
156 changes: 156 additions & 0 deletions src/commands/aipr.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import { command } from 'cleye';
import { execa } from 'execa';
import {
black,
dim,
green,
red,
bgCyan,
} from 'kolorist';
import {
intro,
outro,
spinner,
select,
confirm,
isCancel,
} from '@clack/prompts';
import {
assertGitRepo,
getDetectedMessage,
getStagedDiffFromTrunk,
} from '../utils/git.js';
import { getConfig } from '../utils/config.js';
import { generatePullRequest } from '../utils/openai.js';
import { KnownError, handleCliError } from '../utils/error.js';

export default command(
{
name: 'pr',
/**
* Since this is a wrapper around `gh pr create`,
* flags should not overlap with it
* https://cli.github.com/manual/gh_pr_create
*/
flags: {
generate: {
type: Number,
description: 'Number of messages to generate (Warning: generating multiple costs more) (default: 1)',
alias: 'g',
},
trunkBranch: {
type: String,
description: 'The branch into which you want your code merged',
alias: 'B',
},
exclude: {
type: [String],
description: 'Files to exclude from AI analysis',
alias: 'x',
},
all: {
type: Boolean,
description: 'Automatically stage changes in tracked files for the commit',
alias: 'A',
default: false,
},
},
},
(argv) => {
(async () => {
const {
all: stageAll,
exclude: excludeFiles,
trunkBranch: trunk,
generate,
} = argv.flags;

intro(bgCyan(black(' aipr ')));
await assertGitRepo();

const detectingFiles = spinner();

if (stageAll) {
// This should be equivalent behavior to `git commit --all`
await execa('git', ['add', '--update']);
}

detectingFiles.start('Detecting staged files');
const staged = await getStagedDiffFromTrunk(trunk, excludeFiles);

if (!staged) {
detectingFiles.stop('Detecting staged files');
throw new KnownError(
'No staged changes found. Stage your changes manually, or automatically stage all changes with the `--all` flag.',
);
}

detectingFiles.stop(`${getDetectedMessage(staged.files)}:\n${staged.files.map(file => ` ${file}`).join('\n')}`);

const { env } = process;
const config = await getConfig({
OPENAI_KEY: env.OPENAI_KEY || env.OPENAI_API_KEY,
proxy:
env.https_proxy
|| env.HTTPS_PROXY
|| env.http_proxy
|| env.HTTP_PROXY,
generate: generate?.toString(),
});

const s = spinner();
s.start('The AI is analyzing your changes');
let messages: string[];
try {
messages = await generatePullRequest(
config.OPENAI_KEY,
config.model,
config.locale,
staged.diff,
config.generate,
config.timeout,
config.proxy,
);
} finally {
s.stop('Changes analyzed');
}

if (messages.length === 0) {
throw new KnownError('A PR was not generated. Try again.');
}

let message: string;
if (messages.length === 1) {
[message] = messages;
const confirmed = await confirm({
message: `Use this PR?\n\n ${message}\n`,
});

if (!confirmed || isCancel(confirmed)) {
outro('PR cancelled');
return;
}
} else {
const selected = await select({
message: `Pick a PR to use: ${dim('(Ctrl+c to exit)')}`,
options: messages.map(value => ({ label: value, value })),
});

if (isCancel(selected)) {
outro('PR cancelled');
return;
}

message = selected;
}

await execa('gh', ['pr', 'create', '-b', `"${message}"`, '-t', `"${message.split('\n')[0]}"`, '-B', trunk ?? 'main']);

outro(`${green('✔')} Successfully created!`);
})().catch((error) => {
console.error(`${red('✖')} ${error.message}`);
handleCliError(error);
process.exit(1);
});
},
);
43 changes: 43 additions & 0 deletions src/utils/git.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,47 @@ export const getStagedDiff = async (excludeFiles?: string[]) => {
};
};

export const getStagedDiffFromTrunk = async (trunkBranch?: string, excludeFiles?: string[]) => {
const trunk = trunkBranch ?? 'main';
const branchesToCompare = `${trunk}..HEAD`;

const { stdout: files } = await execa(
'git',
[
'diff',
branchesToCompare,
'--name-only',
...filesToExclude,
...(
excludeFiles
? excludeFiles.map(excludeFromDiff)
: []
),
],
);

if (!files) {
return;
}

const { stdout: diff } = await execa(
'git',
[
'diff',
branchesToCompare,
...filesToExclude,
...(
excludeFiles
? excludeFiles.map(excludeFromDiff)
: []
),
],
);

return {
files: files.split('\n'),
diff,
};
};

export const getDetectedMessage = (files: string[]) => `Detected ${files.length.toLocaleString()} staged file${files.length > 1 ? 's' : ''}`;
55 changes: 53 additions & 2 deletions src/utils/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
import createHttpsProxyAgent from 'https-proxy-agent';
import { KnownError } from './error.js';
import type { CommitType } from './config.js';
import { generatePrompt } from './prompt.js';
import { generateCommitPrompt, generatePullRequestPrompt } from './prompt.js';

const httpsPost = async (
hostname: string,
Expand Down Expand Up @@ -141,7 +141,7 @@ export const generateCommitMessage = async (
messages: [
{
role: 'system',
content: generatePrompt(locale, maxLength, type),
content: generateCommitPrompt(locale, maxLength, type),
},
{
role: 'user',
Expand Down Expand Up @@ -174,3 +174,54 @@ export const generateCommitMessage = async (
throw errorAsAny;
}
};

export const generatePullRequest = async (
apiKey: string,
model: TiktokenModel,
locale: string,
diff: string,
completions: number,
timeout: number,
proxy?: string,
) => {
try {
const completion = await createChatCompletion(
apiKey,
{
model,
messages: [
{
role: 'system',
content: generatePullRequestPrompt(locale),
},
{
role: 'user',
content: diff,
},
],
temperature: 0.7,
top_p: 1,
frequency_penalty: 0,
presence_penalty: 0,
max_tokens: 200,
stream: false,
n: completions,
},
timeout,
proxy,
);

return deduplicateMessages(
completion.choices
.filter(choice => choice.message?.content)
.map(choice => choice.message!.content),
);
} catch (error) {
const errorAsAny = error as any;
if (errorAsAny.code === 'ENOTFOUND') {
throw new KnownError(`Error connecting to ${errorAsAny.hostname} (${errorAsAny.syscall}). Are you connected to the internet?`);
}

throw errorAsAny;
}
};
10 changes: 9 additions & 1 deletion src/utils/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const commitTypes: Record<CommitType, string> = {
}`,
};

export const generatePrompt = (
export const generateCommitPrompt = (
locale: string,
maxLength: number,
type: CommitType,
Expand All @@ -46,3 +46,11 @@ export const generatePrompt = (
commitTypes[type],
specifyCommitFormat(type),
].filter(Boolean).join('\n');

export const generatePullRequestPrompt = (
locale: string,
) => [
'Generate the content for a descriptive Github Pull Request written in present tense for the following code diff with the given specifications below:',
`Message language: ${locale}`,
'Exclude anything unnecessary such as translation. Use Markdown syntax if you want to emphasize certain changes, list out lists of changes, or separate the content with headings and newlines. Assume your audience is a team member who has no context on what you\'re working on. Your entire response will be passed directly into gh pr create.',
].filter(Boolean).join('\n');