Skip to content

Commit

Permalink
Store models meta information in a CSV file, so it can be further use…
Browse files Browse the repository at this point in the history
…d in data visualization

Part of #296
  • Loading branch information
ruiAzevedo19 committed Aug 1, 2024
1 parent a31b419 commit 37f2824
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 1 deletion.
30 changes: 29 additions & 1 deletion cmd/eval-dev-quality/cmd/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (
"github.com/symflower/eval-dev-quality/evaluate"
"github.com/symflower/eval-dev-quality/evaluate/report"
"github.com/symflower/eval-dev-quality/log"
"github.com/symflower/eval-dev-quality/util"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/provider"
)

// Report holds the "report" command.
Expand Down Expand Up @@ -80,6 +81,33 @@ func (command *Report) Execute(args []string) (err error) {
command.logger.Panicf("ERROR: %s", err)
}

// Create a CSV file that holds the models meta information.
var modelsMetaInformationCSVFile *os.File
if modelsMetaInformationCSVFile, err = os.OpenFile(filepath.Join(command.ResultPath, "meta.csv"), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0755); err != nil {
command.logger.Panicf("ERROR: %s", err)
}
defer modelsMetaInformationCSVFile.Close()

// Fetch all models meta information.
var modelsMetaInformation []*model.MetaInformation
for _, provider := range provider.Providers {
models, err := provider.Models()
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
for _, model := range models {
if modelMetaInformation := model.MetaInformation(); modelMetaInformation != nil {
modelsMetaInformation = append(modelsMetaInformation, model.MetaInformation())
}
}
}
metaInformationRecords := report.MetaInformationRecords(modelsMetaInformation)
report.SortEvaluationRecords(metaInformationRecords)
// Write models meta information to disk.
if err := report.WriteMetaInformationRecords(modelsMetaInformationCSVFile, metaInformationRecords); err != nil {
command.logger.Panicf("ERROR: %s", err)
}

// Write markdown reports.
assessmentsPerModel, err := report.RecordsToAssessmentsPerModel(records)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions cmd/eval-dev-quality/cmd/report_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ func TestReportExecute(t *testing.T) {
expectedContent := fmt.Sprintf("%s\n%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent)
assert.Equal(t, expectedContent, data)
},
filepath.Join("result-directory", "meta.csv"): nil,
},
})
validate(t, &testCase{
Expand Down Expand Up @@ -213,6 +214,7 @@ func TestReportExecute(t *testing.T) {
expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent)
assert.Equal(t, expectedContent, data)
},
filepath.Join("result-directory", "meta.csv"): nil,
},
})
validate(t, &testCase{
Expand Down Expand Up @@ -253,6 +255,7 @@ func TestReportExecute(t *testing.T) {
expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent)
assert.Equal(t, expectedContent, data)
},
filepath.Join("result-directory", "meta.csv"): nil,
},
})
}
Expand Down
34 changes: 34 additions & 0 deletions evaluate/report/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,40 @@ func assessmentFromRecord(assessmentFields []string) (assessments metrics.Assess
return assessments, nil
}

// MetaInformationRecords converts the models meta information into CSV records.
func MetaInformationRecords(modelsMetaInformation []*model.MetaInformation) (records [][]string) {
records = [][]string{}

for _, metaInformation := range modelsMetaInformation {
records = append(records, []string{
metaInformation.ID,
metaInformation.Name,
strconv.FormatFloat(metaInformation.Pricing.Completion, 'f', -1, 64),
strconv.FormatFloat(metaInformation.Pricing.Image, 'f', -1, 64),
strconv.FormatFloat(metaInformation.Pricing.Prompt, 'f', -1, 64),
strconv.FormatFloat(metaInformation.Pricing.Request, 'f', -1, 64),
})
}

return records
}

// WriteMetaInformationRecords writes the meta information records into a CSV file.
func WriteMetaInformationRecords(writer io.Writer, records [][]string) (err error) {
csv := csv.NewWriter(writer)

header := []string{"model-id", "model-name", "completion", "image", "prompt", "request"}
if err := csv.Write(header); err != nil {
return pkgerrors.WithStack(err)
}
if err := csv.WriteAll(records); err != nil {
return pkgerrors.WithStack(err)
}
csv.Flush()

return nil
}

// SortEvaluationRecords sorts the evaluation records.
func SortEvaluationRecords(records [][]string) {
sort.Slice(records, func(i, j int) bool {
Expand Down
86 changes: 86 additions & 0 deletions evaluate/report/csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/symflower/eval-dev-quality/evaluate/metrics"
evaluatetask "github.com/symflower/eval-dev-quality/evaluate/task"
languagetesting "github.com/symflower/eval-dev-quality/language/testing"
"github.com/symflower/eval-dev-quality/model"
modeltesting "github.com/symflower/eval-dev-quality/model/testing"
"github.com/symflower/eval-dev-quality/task"
)
Expand Down Expand Up @@ -480,3 +481,88 @@ func TestRecordsToAssessmentsPerModel(t *testing.T) {
},
})
}

func TestWriteMetaInformationRecords(t *testing.T) {
var file strings.Builder

err := WriteMetaInformationRecords(&file, [][]string{
[]string{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"},
[]string{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"},
[]string{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"},
[]string{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"},
[]string{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"},
})
require.NoError(t, err)

assert.Equal(t, bytesutil.StringTrimIndentations(`
model-id,model-name,completion,image,prompt,request
provider/modelA,modelA,0.1,0.2,0.3,0.4
provider/modelB,modelB,0.01,0.02,0.03,0.04
provider/modelC,modelC,0.001,0.002,0.003,0.004
provider/modelD,modelD,0.0001,0.0002,0.0003,0.0004
provider/modelE,modelE,0.00001,0.00002,0.00003,0.00004
`), file.String())
}

func TestMetaInformationRecords(t *testing.T) {
actualRecords := MetaInformationRecords([]*model.MetaInformation{
&model.MetaInformation{
ID: "provider/modelA",
Name: "modelA",
Pricing: model.Pricing{
Completion: 0.1,
Image: 0.2,
Prompt: 0.3,
Request: 0.4,
},
},
&model.MetaInformation{
ID: "provider/modelB",
Name: "modelB",
Pricing: model.Pricing{
Completion: 0.01,
Image: 0.02,
Prompt: 0.03,
Request: 0.04,
},
},
&model.MetaInformation{
ID: "provider/modelC",
Name: "modelC",
Pricing: model.Pricing{
Completion: 0.001,
Image: 0.002,
Prompt: 0.003,
Request: 0.004,
},
},
&model.MetaInformation{
ID: "provider/modelD",
Name: "modelD",
Pricing: model.Pricing{
Completion: 0.0001,
Image: 0.0002,
Prompt: 0.0003,
Request: 0.0004,
},
},
&model.MetaInformation{
ID: "provider/modelE",
Name: "modelE",
Pricing: model.Pricing{
Completion: 0.00001,
Image: 0.00002,
Prompt: 0.00003,
Request: 0.00004,
},
},
})

assert.ElementsMatch(t, [][]string{
[]string{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"},
[]string{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"},
[]string{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"},
[]string{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"},
[]string{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"},
}, actualRecords)
}

0 comments on commit 37f2824

Please sign in to comment.