Skip to content

Commit

Permalink
Let the LLM models have meta information such as pricing and human-re…
Browse files Browse the repository at this point in the history
…adable names

Part of #296
  • Loading branch information
ruiAzevedo19 committed Jul 31, 2024
1 parent b1c89fb commit cf3c0f1
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 12 deletions.
4 changes: 2 additions & 2 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
if command.Runtime != "local" {
// Copy the models over.
for _, modelID := range command.Models {
evaluationContext.Models = append(evaluationContext.Models, llm.NewModel(nil, modelID))
evaluationContext.Models = append(evaluationContext.Models, llm.NewModel(nil, modelID, nil))
}

return evaluationContext, evaluationConfiguration, func() {}
Expand Down Expand Up @@ -262,7 +262,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
command.logger.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
}

modelProvider.AddModel(llm.NewModel(modelProvider, model))
modelProvider.AddModel(llm.NewModel(modelProvider, model, nil))
}
}

Expand Down
6 changes: 3 additions & 3 deletions evaluate/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func TestEvaluate(t *testing.T) {
languageGolang := &golang.Language{}
mockedModelID := "testing-provider/empty-response-model"
mockedQuery := providertesting.NewMockQuery(t)
mockedModel := llm.NewModel(mockedQuery, mockedModelID)
mockedModel := llm.NewModel(mockedQuery, mockedModelID, nil)
repositoryPath := filepath.Join("golang", "plain")

validate(t, &testCase{
Expand Down Expand Up @@ -290,7 +290,7 @@ func TestEvaluate(t *testing.T) {
languageGolang := &golang.Language{}
mockedModelID := "testing-provider/empty-response-model"
mockedQuery := providertesting.NewMockQuery(t)
mockedModel := llm.NewModel(mockedQuery, mockedModelID)
mockedModel := llm.NewModel(mockedQuery, mockedModelID, nil)
repositoryPath := filepath.Join("golang", "plain")

validate(t, &testCase{
Expand Down Expand Up @@ -361,7 +361,7 @@ func TestEvaluate(t *testing.T) {
languageGolang := &golang.Language{}
mockedModelID := "testing-provider/empty-response-model"
mockedQuery := providertesting.NewMockQuery(t)
mockedModel := llm.NewModel(mockedQuery, mockedModelID)
mockedModel := llm.NewModel(mockedQuery, mockedModelID, nil)
repositoryPath := filepath.Join("golang", "plain")

validate(t, &testCase{
Expand Down
12 changes: 11 additions & 1 deletion model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,28 @@ type Model struct {

// queryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task.
queryAttempts uint

// metaInformation holds a model meta information.
metaInformation *model.MetaInformation
}

// NewModel returns an LLM model corresponding to the given identifier which is queried via the given provider.
func NewModel(provider provider.Query, modelIdentifier string) *Model {
func NewModel(provider provider.Query, modelIdentifier string, metaInformation *model.MetaInformation) *Model {
return &Model{
provider: provider,
model: modelIdentifier,

queryAttempts: 1,

metaInformation: metaInformation,
}
}

// MetaInformation returns the meta information of a model.
func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) {
return m.metaInformation
}

// llmSourceFilePromptContext is the context for template for generating an LLM test generation prompt.
type llmSourceFilePromptContext struct {
// Language holds the programming language name.
Expand Down
6 changes: 3 additions & 3 deletions model/llm/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestModelGenerateTestsForFile(t *testing.T) {

mock := providertesting.NewMockQuery(t)
tc.SetupMock(mock)
llm := NewModel(mock, tc.ModelID)
llm := NewModel(mock, tc.ModelID, nil)

ctx := model.Context{
Language: tc.Language,
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestModelRepairSourceCodeFile(t *testing.T) {
modelID := "some-model"
mock := providertesting.NewMockQuery(t)
tc.SetupMock(t, mock)
llm := NewModel(mock, modelID)
llm := NewModel(mock, modelID, nil)

ctx := model.Context{
Language: tc.Language,
Expand Down Expand Up @@ -496,7 +496,7 @@ func TestModelTranspile(t *testing.T) {
modelID := "some-model"
mock := providertesting.NewMockQuery(t)
tc.SetupMock(t, mock)
llm := NewModel(mock, modelID)
llm := NewModel(mock, modelID, nil)

ctx := model.Context{
Language: tc.Language,
Expand Down
3 changes: 3 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
type Model interface {
// ID returns the unique ID of this model.
ID() (id string)

// MetaInformation returns the meta information of a model.
MetaInformation() *MetaInformation
}

// MetaInformation holds a model.
Expand Down
5 changes: 5 additions & 0 deletions model/symflower/symflower.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ func (m *Model) ID() (id string) {
return "symflower" + provider.ProviderModelSeparator + "symbolic-execution"
}

// MetaInformation returns the meta information of a model.
func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) {
return nil
}

var _ model.CapabilityWriteTests = (*Model)(nil)

// generateTestsForFile generates test files for the given implementation file in a repository.
Expand Down
25 changes: 24 additions & 1 deletion model/testing/Model_mock_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion provider/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (p *Provider) Models() (models []model.Model, err error) {

models = make([]model.Model, len(ms))
for i, modelName := range ms {
models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+modelName)
models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+modelName, nil)
}

return models, nil
Expand Down
2 changes: 1 addition & 1 deletion provider/openrouter/openrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (p *Provider) Models() (models []model.Model, err error) {

models = make([]model.Model, len(responseModels.Models))
for i, model := range responseModels.Models {
models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID)
models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID, &model)
}

return models, nil
Expand Down

0 comments on commit cf3c0f1

Please sign in to comment.