Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-www committed Jan 29, 2022
2 parents 897d961 + 2929263 commit a10dccc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/metrics/classifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export type ClassificationReport = {
recalls: Tensor;
f1s: Tensor;
confusionMatrix: Tensor;
classes: Tensor;
categories: Tensor;
accuracy: number;
averagePrecision: number;
averageRecall: number;
Expand Down Expand Up @@ -249,7 +249,7 @@ export const getClassificationReport = async (yTrue: Tensor | string[] | number[
recalls: recalls,
f1s: f1s,
confusionMatrix: confusionMatrix,
classes: categories,
categories,
accuracy,
averageF1,
averagePrecision,
Expand Down
57 changes: 36 additions & 21 deletions test/node/preprocess/encoder.ts
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
import { eigenSolve } from '../../../src/linalg/eigen';
import '@tensorflow/tfjs-backend-cpu';
import * as tf from '@tensorflow/tfjs-core';
import { tensorEqual } from '../../../src/linalg/utils';
import { assert } from 'chai';
import 'mocha';
import { OneHotEncoder } from '../../../src/preprocess';
import { OneHotEncoder, LabelEncoder } from '../../../src/preprocess';

const x = ['tree', 'apple', 'banana', 'tree', 'apple', 'banana'];

const x = [ 'tree', 'apple', 'banana', 'tree', 'apple', 'banana' ];
const xEncode = tf.tensor([
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[ 1, 0, 0 ],
[ 0, 1, 0 ],
[ 0, 0, 1 ],
[ 1, 0, 0 ],
[ 0, 1, 0 ],
[ 0, 0, 1 ]
]);
const xLabelEncode = tf.tensor([ 0, 1, 2, 0, 1, 2 ]);
const xEncodeDrop = tf.tensor([
[0, 0],
[1, 0],
[0, 1],
[0, 0],
[1, 0],
[0, 1],
[ 0, 0 ],
[ 1, 0 ],
[ 0, 1 ],
[ 0, 0 ],
[ 1, 0 ],
[ 0, 1 ]
]);
const bx = ['tree', 'apple', 'tree', 'apple'];
const bxEncodeDrop = tf.tensor([0, 1, 0, 1]);
const bx = [ 'tree', 'apple', 'tree', 'apple' ];
const bxEncodeDrop = tf.tensor([ 0, 1, 0, 1 ]);

describe('OneHot Encoder', () => {
it('encode', async () => {
Expand All @@ -40,21 +40,36 @@ describe('OneHot Encoder', () => {
assert.deepEqual(xCate.dataSync() as any, x);
});
it('encode drop first', async () => {
const encoder = new OneHotEncoder({drop: 'first'});
const encoder = new OneHotEncoder({ drop: 'first' });
await encoder.init(x);
const xOneHot = await encoder.encode(x);
assert.deepEqual(xOneHot.dataSync(), xEncodeDrop.dataSync());
});
it('encode binary only', async () => {
const encoder = new OneHotEncoder({drop: 'binary-only'});
const encoder = new OneHotEncoder({ drop: 'binary-only' });
await encoder.init(bx);
const bxOneHot = await encoder.encode(bx);
assert.deepEqual(bxOneHot.dataSync(), bxEncodeDrop.dataSync());
});
it('decode binary only', async () => {
const encoder = new OneHotEncoder({drop: 'binary-only'});
const encoder = new OneHotEncoder({ drop: 'binary-only' });
await encoder.init(bx);
const bxCate = await encoder.decode(bxEncodeDrop);
assert.deepEqual(bxCate.dataSync() as any, bx);
});
});

describe('Label Encoder', () => {
it('encode', async () => {
const encoder = new LabelEncoder();
await encoder.init(x);
const xEncode = await encoder.encode(x);
assert.deepEqual(xEncode.dataSync(), xLabelEncode.dataSync());
});
it('decode', async () => {
const encoder = new LabelEncoder();
await encoder.init(x);
const xDecode = await encoder.decode(xLabelEncode);
assert.deepEqual(x, xDecode.dataSync() as any);
});
});

0 comments on commit a10dccc

Please sign in to comment.