From 13127851336ae15892748179f7a3f4170adb0165 Mon Sep 17 00:00:00 2001 From: sugarspectre Date: Wed, 9 Feb 2022 17:48:08 +0800 Subject: [PATCH 1/2] classification metrics --- src/linalg/utils.ts | 20 ++- src/metrics/classifier.ts | 265 ++++++++++++++++++++++++++-- src/metrics/index.ts | 20 ++- src/preprocess/encoder.ts | 95 +++++++--- src/preprocess/index.ts | 2 +- test/node/metrics/classification.ts | 72 ++++++++ 6 files changed, 433 insertions(+), 41 deletions(-) create mode 100644 test/node/metrics/classification.ts diff --git a/src/linalg/utils.ts b/src/linalg/utils.ts index bdcf7837..6583aaae 100644 --- a/src/linalg/utils.ts +++ b/src/linalg/utils.ts @@ -1,4 +1,4 @@ -import { Tensor, RecursiveArray, norm, div, max, sub, abs, lessEqual, slice, tensor, isNaN, where, tidy } from '@tensorflow/tfjs-core'; +import { Tensor, RecursiveArray, norm, div, max, sub, abs, lessEqual, slice, tensor, isNaN, where, tidy, stack, squeeze } from '@tensorflow/tfjs-core'; import { checkArray } from '../utils/validation'; /** @@ -114,3 +114,21 @@ export const fillNaN = (xData: Tensor | RecursiveArray, fillV = 0): Tens return where(cond, fillV, xTensor); }); }; + +/** + * Get the diagonal elements in a matrix + * @param matrix target matrix + * @returns tensor of diagonal elements + */ +export const getDiagElements = (matrix: Tensor | number[]): Tensor => { + return tidy(() => { + const matrixTensor = checkArray(matrix, 'any', 2); + const [ m, n ] = matrixTensor.shape; + const rank = m > n ? 
n : m; + const diagElements = []; + for (let i = 0; i < rank; i++) { + diagElements.push(slice(matrixTensor, [ i, i ], [ 1, 1 ])); + } + return squeeze(stack(diagElements)); + }); +}; diff --git a/src/metrics/classifier.ts b/src/metrics/classifier.ts index f4c8cea2..7806c036 100644 --- a/src/metrics/classifier.ts +++ b/src/metrics/classifier.ts @@ -1,15 +1,258 @@ -import { Tensor, equal, sum, div } from '@tensorflow/tfjs-core'; -import { checkArray } from '../utils/validation'; +import { Tensor, Tensor1D, equal, sum, div, math, divNoNan, concat, mul, add, cast } from '@tensorflow/tfjs-core'; +import { checkSameLength } from '../utils/validation'; +import { getDiagElements } from '../linalg/utils'; +import { LabelEncoder } from '../preprocess'; -export const accuracyScore = ( yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[] ): number => { - const yTrueTensor = checkArray(yTrue, 'any', 1); - const yPredTensor = checkArray(yPred, 'any', 1); - const yTrueCount = yTrueTensor.shape[0]; - const yPredCount = yPredTensor.shape[0]; - if (yTrueCount != yPredCount) { - throw new Error('Shape of yTrue should match shape of yPred'); - } +export type ClassificationReport = { + precisions: Tensor; + recalls: Tensor; + f1s: Tensor; + confusionMatrix: Tensor; + categories: Tensor; + accuracy: number; + averagePrecision: number; + averageRecall: number; + averageF1: number; +}; + +export type ClassificationAverageTypes = 'macro' | 'weighted' | 'micro' + +export const accuracyScore = (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[]): number => { + const [ yTrueTensor, yPredTensor, nLabels ] = checkSameLength(yTrue, yPred); // TODO(sugarspectre): Accuaracy score computation - const score = div(sum(equal(yPred, yTrue)), yTrueCount).dataSync()[0]; + const score = div(sum(equal(yPredTensor, yTrueTensor)), nLabels).dataSync()[0]; return score; }; + +export const getConfusionMatrix = async (yTrue: Tensor | string[] | number[], yPred: Tensor | 
string[] | number[]): Promise<{ confusionMatrix: Tensor, categories: Tensor }> => { + const [ yTrueTensor, yPredTensor ] = checkSameLength(yTrue, yPred); + const labelEncoder = new LabelEncoder(); + await labelEncoder.init(concat([ yTrueTensor, yPredTensor ])); + const yTrueEncode = await labelEncoder.encode(yTrueTensor); + const yPredEncode = await labelEncoder.encode(yPredTensor); + const numClasses = labelEncoder.categories.shape[0]; + const confusionMatrix = cast(math.confusionMatrix(yTrueEncode as Tensor1D, yPredEncode as Tensor1D, numClasses), 'float32'); + return { confusionMatrix, categories: labelEncoder.categories }; +}; + +const getPrecisionScoreByConfusionMatrix = (confusionMatrix: Tensor, average: ClassificationAverageTypes = 'micro'): number => { + const confusionDiag = getDiagElements(confusionMatrix); + const numClasses = confusionMatrix.shape[0]; + const precisions = divNoNan(confusionDiag, sum(confusionMatrix, 0)); + const weights = divNoNan(sum(confusionMatrix, 0), sum(confusionMatrix)); + const weightsSupport = divNoNan(sum(confusionMatrix, 1), sum(confusionMatrix)); + switch (average) { + case 'micro': + return sum(mul(precisions, weights)).dataSync()[0]; + case 'macro': + return divNoNan(sum(precisions), numClasses).dataSync()[0]; + case 'weighted': + return sum(mul(precisions, weightsSupport)).dataSync()[0]; + default: + return sum(mul(precisions, weights)).dataSync()[0]; + } +}; + +const getRecallScoreByConfusionMatrix = (confusionMatrix: Tensor, average: ClassificationAverageTypes = 'micro'): number => { + const confusionDiag = getDiagElements(confusionMatrix); + const numClasses = confusionMatrix.shape[0]; + const recalls = divNoNan(confusionDiag, sum(confusionMatrix, 1)); + const weights = divNoNan(sum(confusionMatrix, 1), sum(confusionMatrix)); + switch (average) { + case 'micro': + return sum(mul(recalls, weights)).dataSync()[0]; + case 'macro': + return divNoNan(sum(recalls), numClasses).dataSync()[0]; + case 'weighted': + return 
sum(mul(recalls, weights)).dataSync()[0]; + default: + return sum(mul(recalls, weights)).dataSync()[0]; + } +}; + +const getF1ScoreByConfusionMatrix = (confusionMatrix: Tensor, average: ClassificationAverageTypes = 'micro'): number => { + const confusionDiag = getDiagElements(confusionMatrix); + const numClasses = confusionMatrix.shape[0]; + const precisions = divNoNan(confusionDiag, sum(confusionMatrix, 0)); + const recalls = divNoNan(confusionDiag, sum(confusionMatrix, 1)); + const f1s = divNoNan(mul(mul(2, precisions), recalls), add(precisions, recalls)); + const weights = divNoNan(sum(confusionMatrix, 0), sum(confusionMatrix)); + const weightsSupport = divNoNan(sum(confusionMatrix, 1), sum(confusionMatrix)); + switch (average) { + case 'micro': { + const precision = sum(mul(precisions, weights)).dataSync()[0]; + const recall = sum(mul(recalls, weightsSupport)).dataSync()[0]; + return (divNoNan(mul(mul(2, precision), recall), add(precision, recall))).dataSync()[0]; + } + case 'macro': + return divNoNan(sum(f1s), numClasses).dataSync()[0]; + case 'weighted': + return sum(mul(f1s, weightsSupport)).dataSync()[0]; + default: + return sum(mul(f1s, weights)).dataSync()[0]; + } +}; + +/** + * Compute the precision score for all classes. + * Precision score is the ratio tp / (tp + fp), where tp is the number of true positives + * and fp the number of false positives. The precision is intuitively the ability of + * the classifier not to label as positive a sample that is negative. + * The best value is 1 and the worst value is 0. + * @param yTrue Ground truth (correct) target values. + * @param yPred Estimated targets as returned by a classifier. 
+ * @returns Tensor of precision scores + */ +export const getPrecisionScores = async (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[]): Promise<{ precisions: Tensor, categories: Tensor }> => { + const { confusionMatrix, categories } = await getConfusionMatrix(yTrue, yPred); + const confusionDiag = getDiagElements(confusionMatrix); + const precisions = divNoNan(confusionDiag, sum(confusionMatrix, 0)); + return { precisions: precisions, categories }; +}; + +/** + * Compute the recall score for all classes. + * Recall score is the ratio tp / (tp + fn), where tp is the number of true positives + * and fn the number of false negatives. The recall is intuitively the ability of the + * classifier to find all the positive samples. + * The best value is 1 and the worst value is 0. + * @param yTrue Ground truth (correct) target values. + * @param yPred Estimated targets as returned by a classifier. + * @returns Tensor of recall scores + */ +export const getRecallScores = async (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[]): Promise<{ recalls: Tensor, categories: Tensor }> => { + const { confusionMatrix, categories } = await getConfusionMatrix(yTrue, yPred); + const confusionDiag = getDiagElements(confusionMatrix); + return { recalls: divNoNan(confusionDiag, sum(confusionMatrix, 1)), categories: categories }; +}; + +/** + * Compute the f1 score for all classes. + * The F1 score can be interpreted as a harmonic mean of the precision and recall, + * where an F1 score reaches its best value at 1 and worst score at 0. The relative + * contribution of precision and recall to the F1 score are equal. The formula for the F1 score is: + * `2 * precision * recall / (precision + recall)` + * @param yTrue Ground truth (correct) target values. + * @param yPred Estimated targets as returned by a classifier. 
+ * @returns Tensor of f1 scores + */ +export const getF1Scores = async (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[]): Promise<{ f1s: Tensor, categories: Tensor }> => { + const { confusionMatrix, categories } = await getConfusionMatrix(yTrue, yPred); + const confusionDiag = getDiagElements(confusionMatrix); + const precisions = divNoNan(confusionDiag, sum(confusionMatrix, 0)); + const recalls = divNoNan(confusionDiag, sum(confusionMatrix, 1)); + const f1s = divNoNan(mul(mul(2, precisions), recalls), add(precisions, recalls)); + return { f1s, categories }; +}; + +/** + * Compute the precision score. + * Precision score is the ratio tp / (tp + fp), where tp is the number of true positives + * and fp the number of false positives. The precision is intuitively the ability of + * the classifier not to label as positive a sample that is negative. + * The best value is 1 and the worst value is 0. + * @param yTrue Ground truth (correct) target values. + * @param yPred Estimated targets as returned by a classifier. + * @param average \{'micro', 'macro', 'weighted'\} + * This parameter is required for multiclass/multilabel targets. This determines the type of averaging + * performed on the data, **default='micro'**: + * - `'micro'`: + * Calculate metrics globally by counting the total true positives, false negatives and false positives. + * - `'macro'`: + * Calculate metrics for each label, and find their unweighted mean. This does not take label + * imbalance into account. + * - `'weighted'`: + * Calculate metrics for each label, and find their average weighted by support (the number of true + * instances for each label). This alters ‘macro’ to account for label imbalance; it can result in an + * F-score that is not between precision and recall. 
+ * @returns precision score + */ +export const getPrecisionScore = async (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[], average: ClassificationAverageTypes = 'micro'): Promise => { + const { confusionMatrix } = await getConfusionMatrix(yTrue, yPred); + return getPrecisionScoreByConfusionMatrix(confusionMatrix, average); +}; + + +/** + * Compute the recall score. + * Recall score is the ratio tp / (tp + fn), where tp is the number of true positives + * and fn the number of false negatives. The recall is intuitively the ability of the + * classifier to find all the positive samples. + * The best value is 1 and the worst value is 0. + * @param yTrue Ground truth (correct) target values. + * @param yPred Estimated targets as returned by a classifier. + * @param average \{'micro', 'macro', 'weighted'\} + * This parameter is required for multiclass/multilabel targets. This determines the type of averaging + * performed on the data, **default='micro'**: + * - `'micro'`: + * Calculate metrics globally by counting the total true positives, false negatives and false positives. + * - `'macro'`: + * Calculate metrics for each label, and find their unweighted mean. This does not take label + * imbalance into account. + * - `'weighted'`: + * Calculate metrics for each label, and find their average weighted by support (the number of true + * instances for each label). This alters ‘macro’ to account for label imbalance; it can result in an + * F-score that is not between precision and recall. + * @returns recall score + */ +export const getRecallScore = async (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[], average: ClassificationAverageTypes = 'micro'): Promise => { + const { confusionMatrix } = await getConfusionMatrix(yTrue, yPred); + return getRecallScoreByConfusionMatrix(confusionMatrix, average); +}; + + +/** + * Compute the f1 score. 
+ * The F1 score can be interpreted as a harmonic mean of the precision and recall, + * where an F1 score reaches its best value at 1 and worst score at 0. The relative + * contribution of precision and recall to the F1 score are equal. The formula for the F1 score is: + * `2 * precision * recall / (precision + recall)` + * @param yTrue Ground truth (correct) target values. + * @param yPred Estimated targets as returned by a classifier. + * @param average \{'micro', 'macro', 'weighted'\} + * This parameter is required for multiclass/multilabel targets. This determines the type of averaging + * performed on the data, **default='micro'**: + * - `'micro'`: + * Calculate metrics globally by counting the total true positives, false negatives and false positives. + * - `'macro'`: + * Calculate metrics for each label, and find their unweighted mean. This does not take label + * imbalance into account. + * - `'weighted'`: + * Calculate metrics for each label, and find their average weighted by support (the number of true + * instances for each label). This alters ‘macro’ to account for label imbalance; it can result in an + * F-score that is not between precision and recall. 
+ * @returns f1 score + */ +export const getF1Score = async (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[], average: ClassificationAverageTypes = 'micro'): Promise => { + const { confusionMatrix } = await getConfusionMatrix(yTrue, yPred); + return getF1ScoreByConfusionMatrix(confusionMatrix, average); +}; + +/** + * Generate classification report + * @param yTrue true labels + * @param yPred predicted labels + * @returns classification report object; see the `ClassificationReport` type for its fields + */ +export const getClassificationReport = async (yTrue: Tensor | string[] | number[], yPred: Tensor | string[] | number[], average: ClassificationAverageTypes = 'weighted'): Promise => { + const { confusionMatrix, categories } = await getConfusionMatrix(yTrue, yPred); + const confusionDiag = getDiagElements(confusionMatrix); + const precisions = divNoNan(confusionDiag, sum(confusionMatrix, 0)); + const recalls = divNoNan(confusionDiag, sum(confusionMatrix, 1)); + const f1s = mul(divNoNan(mul(precisions, recalls), add(precisions, recalls)), 2); + const accuracy = accuracyScore(yTrue, yPred); + const averagePrecision = getPrecisionScoreByConfusionMatrix(confusionMatrix, average); + const averageRecall = getRecallScoreByConfusionMatrix(confusionMatrix, average); + const averageF1 = getF1ScoreByConfusionMatrix(confusionMatrix, average); + return { + precisions: precisions, + recalls: recalls, + f1s: f1s, + confusionMatrix: confusionMatrix, + categories, + accuracy, + averageF1, + averagePrecision, + averageRecall + }; +}; diff --git a/src/metrics/index.ts b/src/metrics/index.ts index 517b0973..9bcc9fe7 100644 --- a/src/metrics/index.ts +++ b/src/metrics/index.ts @@ -1 +1,19 @@ -export { accuracyScore } from './classifier'; +export { + accuracyScore, + getClassificationReport, + getF1Score, + getF1Scores, + getPrecisionScore, + getPrecisionScores, + getRecallScore, + getRecallScores, + getConfusionMatrix +} from './classifier'; + +export { + 
getRSquare, + getAICLM, + getMeanSquaredError, + getAdjustedRSquare, + getResidualVariance +} from './regression'; diff --git a/src/preprocess/encoder.ts b/src/preprocess/encoder.ts index ef30fb79..3ab7950b 100644 --- a/src/preprocess/encoder.ts +++ b/src/preprocess/encoder.ts @@ -1,4 +1,4 @@ -import { Tensor, unique, oneHot, cast, tensor, argMax, reshape, slice, stack, sub, squeeze, greaterEqual, topk } from "@tensorflow/tfjs-core"; +import { Tensor, unique, oneHot, cast, tensor, argMax, reshape, slice, stack, sub, squeeze, greaterEqual, topk, Tensor1D, tidy } from "@tensorflow/tfjs-core"; import { checkArray } from "../utils/validation"; import { checkShape } from "../linalg/utils"; @@ -11,13 +11,40 @@ export type OneHotEncoderParams = { drop: OneHotDropTypes } +export abstract class EncoderBase { + public categories: Tensor; + public cateMap: CateMap; + /** + * Init encoder + * @param x data input used to init encoder + * @param categories user input categories + */ + public async init(x: Tensor | number[] | string[]): Promise { + const { values } = unique(x); + if (values.dtype === 'int32' || values.dtype === 'float32') { + this.categories = topk(values, values.shape[0], false).values; + } else if (values.dtype === 'bool') { + this.categories = tensor([ false, true ]); + } else { + this.categories = values; + } + const cateData = await this.categories.data(); + const cateMap: CateMap = {}; + for (let i = 0; i < cateData.length; i++) { + const key = cateData[i]; + cateMap[key] = i; + } + this.cateMap = cateMap; + } + abstract encode(x: Tensor | number[] | string[]): Promise; + abstract decode(x: Tensor): Promise; +} + /** * Encode categorical features as a one-hot numeric array. * */ -export class OneHotEncoder { - public categories: Tensor; - public cateMap: CateMap; +export class OneHotEncoder extends EncoderBase{ public drop: OneHotDropTypes; /** @@ -39,31 +66,10 @@ export class OneHotEncoder { * categories. 
*/ public constructor (params: OneHotEncoderParams = { drop: 'none' }) { + super(); const { drop } = params; this.drop = drop; } - /** - * Init one-hot encoder - * @param x data input used to init encoder - * @param categories user input categories - */ - public async init(x: Tensor | number[] | string[]): Promise { - const { values } = unique(x); - if (values.dtype === 'int32' || values.dtype === 'float32') { - this.categories = topk(values, values.shape[0], false).values; - } else if (values.dtype === 'bool') { - this.categories = tensor([ false, true ]); - } else { - this.categories = values; - } - const cateData = await this.categories.data(); - const cateMap: CateMap = {}; - for (let i = 0; i < cateData.length; i++) { - const key = cateData[i]; - cateMap[key] = i; - } - this.cateMap = cateMap; - } /** * Encode a given feature into one-hot format @@ -75,7 +81,7 @@ export class OneHotEncoder { throw TypeError('Please init encoder using init()'); } const xTensor = checkArray(x, 'any', 1); - const xData = await xTensor.dataSync(); + const xData = await xTensor.data(); const nCate = this.categories.shape[0]; const xInd = xData.map((d: number|string) => this.cateMap[d]); if (this.drop === 'binary-only' && nCate === 2) { @@ -126,3 +132,38 @@ export class OneHotEncoder { return reshape(stack(cateTensors), [ -1 ]); } } + +export class LabelEncoder extends EncoderBase { + /** + * Encode a given feature into integer category indices + * @param x feature array need to encode + * @returns tensor of category indices + */ + public async encode(x: Tensor | number[] | string[]): Promise { + if (!this.categories) { + throw TypeError('Please init encoder using init()'); + } + const xTensor = checkArray(x, 'any', 1); + const xData = await xTensor.data(); + xTensor.dispose(); + return tensor(xData.map((d: number|string) => this.cateMap[d])); + } + /** + * Decode an array of category indices to the original category array + * @param x encoded data need to transform + * @returns transformed category 
data + */ + public async decode(x: Tensor | number[]): Promise { + if (!this.categories) { + throw TypeError('Please init encoder using init()'); + } + const xData: number[] = x instanceof Tensor ? await (x as Tensor1D).array() : x; + const cateTensors: Tensor[] = []; + xData.forEach((ind: number) => { + cateTensors.push(slice(this.categories, ind, 1)); + }); + return tidy(() => { + return reshape(stack(cateTensors), [ -1 ]); + }); + } +} diff --git a/src/preprocess/index.ts b/src/preprocess/index.ts index f9a0627e..ea118d64 100644 --- a/src/preprocess/index.ts +++ b/src/preprocess/index.ts @@ -1 +1 @@ -export { OneHotEncoder } from './encoder'; +export { OneHotEncoder, LabelEncoder } from './encoder'; diff --git a/test/node/metrics/classification.ts b/test/node/metrics/classification.ts new file mode 100644 index 00000000..1b2c63af --- /dev/null +++ b/test/node/metrics/classification.ts @@ -0,0 +1,72 @@ + +import { getConfusionMatrix, getPrecisionScore, getRecallScore, getF1Score, getF1Scores, getPrecisionScores, getRecallScores, getClassificationReport } from '../../../src/metrics/classifier'; +import { assert } from 'chai'; +import * as tf from '@tensorflow/tfjs-core'; +import { numEqual } from '../../../src/math/utils'; +import { tensorEqual } from '../../../src/linalg'; + + +describe('Metrics', () => { + it('get confusion metrics', async () => { + const yTure = [ 0, 1, 2, 3 ]; + const yPred = [ 0, 1, 2, 3 ]; + const { confusionMatrix } = await getConfusionMatrix(yTure, yPred); + confusionMatrix.print(); + }); + it('precision score', async () => { + const yTure = [ 1, 2, 3, 4, 3 ]; + const yPred = [ 1, 2, 4, 3, 2 ]; + const precisionScore = await getPrecisionScore(yTure, yPred); + assert.isTrue(numEqual(precisionScore, 0.4, 1e-3)); + const precisionScoreMacro = await getPrecisionScore(yTure, yPred, 'macro'); + assert.isTrue(numEqual(precisionScoreMacro, 0.375, 1e-3)); + const precisionScoreWeighted = await getPrecisionScore(yTure, yPred, 'weighted'); + 
assert.isTrue(numEqual(precisionScoreWeighted, 0.3, 1e-3)); + }); + it('recall score', async () => { + const yTure = [ 1, 2, 3, 4, 3 ]; + const yPred = [ 1, 2, 4, 3, 2 ]; + const recallScore = await getRecallScore(yTure, yPred); + assert.isTrue(numEqual(recallScore, 0.4, 1e-3)); + const recallScoreMacro = await getRecallScore(yTure, yPred, 'macro'); + assert.isTrue(numEqual(recallScoreMacro, 0.5, 1e-3)); + const recallScoreWeighted = await getRecallScore(yTure, yPred, 'weighted'); + assert.isTrue(numEqual(recallScoreWeighted, 0.4, 1e-3)); + }); + it('f1 score', async () => { + const yTure = [ 1, 2, 3, 4, 3 ]; + const yPred = [ 1, 2, 4, 3, 2 ]; + const f1Score = await getF1Score(yTure, yPred); + assert.isTrue(numEqual(f1Score, 0.4, 1e-3)); + const f1ScoreMacro = await getF1Score(yTure, yPred, 'macro'); + assert.isTrue(numEqual(f1ScoreMacro, 0.416, 1e-3)); + const f1ScoreWeighted = await getF1Score(yTure, yPred, 'weighted'); + assert.isTrue(numEqual(f1ScoreWeighted, 0.333, 1e-3)); + }); + it('precision scores', async () => { + const yTure = [ 1, 2, 3, 4, 3 ]; + const yPred = [ 1, 2, 4, 3, 2 ]; + const { precisions } = await getPrecisionScores(yTure, yPred); + assert.isTrue(tensorEqual(precisions, tf.tensor([ 1, 0.5, 0, 0 ]))); + }); + it('recall scores', async () => { + const yTure = [ 1, 2, 3, 4, 3 ]; + const yPred = [ 1, 2, 4, 3, 2 ]; + const { recalls } = await getRecallScores(yTure, yPred); + assert.isTrue(tensorEqual(recalls, tf.tensor([ 1, 1, 0, 0 ]))); + }); + it('f1 scores', async () => { + const yTure = [ 1, 2, 3, 4, 3 ]; + const yPred = [ 1, 2, 4, 3, 2 ]; + const { f1s } = await getF1Scores(yTure, yPred); + assert.isTrue(tensorEqual(f1s, tf.tensor([ 1, 0.667, 0, 0 ]), 1e-3)); + }); + it('classification report', async () => { + const yTure = [ 1, 2, 3, 4, 3 ]; + const yPred = [ 1, 2, 4, 3, 2 ]; + const { f1s, precisions, recalls } = await getClassificationReport(yTure, yPred); + assert.isTrue(tensorEqual(f1s, tf.tensor([ 1, 0.667, 0, 0 ]), 1e-3)); + 
assert.isTrue(tensorEqual(precisions, tf.tensor([ 1, 0.5, 0, 0 ]), 1e-3)); + assert.isTrue(tensorEqual(recalls, tf.tensor([ 1, 1, 0, 0 ]), 1e-3)); + }); +}); From 99d85077cb41057da6287ba38b3a94473420bb00 Mon Sep 17 00:00:00 2001 From: sugarspectre Date: Thu, 10 Feb 2022 13:09:47 +0800 Subject: [PATCH 2/2] fix lint --- src/metrics/classifier.ts | 8 ++++---- src/preprocess/encoder.ts | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/metrics/classifier.ts b/src/metrics/classifier.ts index 7806c036..60fd730d 100644 --- a/src/metrics/classifier.ts +++ b/src/metrics/classifier.ts @@ -245,10 +245,10 @@ export const getClassificationReport = async (yTrue: Tensor | string[] | number[ const averageRecall = getRecallScoreByConfusionMatrix(confusionMatrix, average); const averageF1 = getF1ScoreByConfusionMatrix(confusionMatrix, average); return { - precisions: precisions, - recalls: recalls, - f1s: f1s, - confusionMatrix: confusionMatrix, + precisions, + recalls, + f1s, + confusionMatrix, categories, accuracy, averageF1, diff --git a/src/preprocess/encoder.ts b/src/preprocess/encoder.ts index 3ab7950b..c5bffc6b 100644 --- a/src/preprocess/encoder.ts +++ b/src/preprocess/encoder.ts @@ -1,6 +1,6 @@ -import { Tensor, unique, oneHot, cast, tensor, argMax, reshape, slice, stack, sub, squeeze, greaterEqual, topk, Tensor1D, tidy } from "@tensorflow/tfjs-core"; -import { checkArray } from "../utils/validation"; -import { checkShape } from "../linalg/utils"; +import { Tensor, unique, oneHot, cast, tensor, argMax, reshape, slice, stack, sub, squeeze, greaterEqual, topk, Tensor1D, tidy } from '@tensorflow/tfjs-core'; +import { checkArray } from '../utils/validation'; +import { checkShape } from '../linalg/utils'; export type CateMap = { [ key: string ]: number @@ -44,7 +44,7 @@ export abstract class EncoderBase { * Encode categorical features as a one-hot numeric array. 
* */ -export class OneHotEncoder extends EncoderBase{ +export class OneHotEncoder extends EncoderBase { public drop: OneHotDropTypes; /**