diff --git a/animated-transformer/src/app/sae/sae.component.html b/animated-transformer/src/app/sae/sae.component.html
index d9b04df..5c7ddf2 100644
--- a/animated-transformer/src/app/sae/sae.component.html
+++ b/animated-transformer/src/app/sae/sae.component.html
@@ -14,4 +14,34 @@
{{status}}
}
+
+@if (trained) {
+  <div class="results">
+
+    @if (learnedFeatureActivationFrequencies.length) {
+      <div class="frequencies">
+        <div>Frequencies (avg = {{averageLearnedFeatureActivationFrequency}}):</div>
+        @for (item of learnedFeatureActivationFrequencies; track $index; let i = $index) {
+          <div>{{i}}: {{item}}</div>
+        }
+      </div>
+    }
+
+    <div class="interpreter">
+      <div>Interpret a feature</div>
+      <input type="number" [(ngModel)]="userInput" />
+      <button mat-button (click)="interpret()">Interpret</button>
+
+      @if (topActivationsForUserInputFeature) {
+        <div>
+          <div>Top activating data for neuron {{userInput}}:</div>
+          @for (item of topActivationsForUserInputFeature; track $index; let i = $index) {
+            <div>{{i}}: {{item.value | number:'1.2-2'}} {{item.token}} (Pos {{item.tokenPos}}: {{item.sequence}})</div>
+          }
+        </div>
+      }
+    </div>
+
+  </div>
+}
\ No newline at end of file
diff --git a/animated-transformer/src/app/sae/sae.component.scss b/animated-transformer/src/app/sae/sae.component.scss
index 3ffdc37..ecf3828 100644
--- a/animated-transformer/src/app/sae/sae.component.scss
+++ b/animated-transformer/src/app/sae/sae.component.scss
@@ -17,4 +17,13 @@
.trainer .status {
margin-top: 20px;
+}
+
+.interpreter, .frequencies {
+ margin-top: 20px;
+}
+
+.results {
+ display: flex;
+ gap: 20px;
}
\ No newline at end of file
diff --git a/animated-transformer/src/app/sae/sae.component.ts b/animated-transformer/src/app/sae/sae.component.ts
index aa12820..020433c 100644
--- a/animated-transformer/src/app/sae/sae.component.ts
+++ b/animated-transformer/src/app/sae/sae.component.ts
@@ -1,4 +1,4 @@
-/* Copyright 2023 Google LLC. All Rights Reserved.
+/* Copyright 2024 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@ limitations under the License.
import { AfterViewInit, Component, OnInit } from '@angular/core';
import * as tf from '@tensorflow/tfjs';
+import { CommonModule } from '@angular/common';
+import { FormsModule } from '@angular/forms';
import { computeTransformer, initDecoderParams } from '../../lib/transformer/transformer_gtensor';
import * as gtensor from '../../lib/gtensor/gtensor';
import { gtensorTrees } from '../../lib/gtensor/gtensor_tree';
@@ -29,14 +31,29 @@ import { BasicLmTask, BasicLmTaskUpdate } from 'src/lib/seqtasks/util';
import { MatButtonModule } from '@angular/material/button';
+const MLP_ACT_SIZE = 8;
+const DICTIONARY_MULTIPLIER = 4;
+const D_HIDDEN = MLP_ACT_SIZE * DICTIONARY_MULTIPLIER; // learned feature size
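+// Weight of the L1 sparsity penalty in the loss below; larger values push more
+// dictionary features toward zero activation at the cost of reconstruction error.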
+const L1_COEFF = 0.003;
+
@Component({
selector: 'app-sae',
+ standalone: true,
templateUrl: './sae.component.html',
- styleUrls: ['./sae.component.scss']
+ styleUrls: ['./sae.component.scss'],
+ imports: [CommonModule, MatButtonModule, FormsModule],
})
export class SAEComponent {
status: string = '';
+ public saeModel: any;
public trainingData: any;
+ public trainingInputs: any;
+ public trained = false;
+ learnedFeatureActivationFrequencies: number[] = [];
+ averageLearnedFeatureActivationFrequency: number = 0;
+ predictedDictionaryFeatures: any;
+ topActivationsForUserInputFeature: any;
+ userInput: any;
constructor(
private route: ActivatedRoute,
private router: Router,
@@ -59,74 +76,117 @@ export class SAEComponent {
reader.readAsText(file);
}
+ async interpret() {
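+ // predictedDictionaryFeatures holds one activation per training token per feature;
+ // slice out column `userInput` to get the selected feature's activation on every token.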
+ const activationsForFeatureToInspect = Array.from(
+ this.predictedDictionaryFeatures.slice([0, this.userInput], [-1, 1]).dataSync());
+ const indexedActivations = activationsForFeatureToInspect.map((value, index) => ({ value, index }));
+ indexedActivations.sort((a: any, b: any) => b.value - a.value); // sort descending by activation strength
+
+ const nTop = 50;
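+ // Keep the strongest activations and attach the token metadata recorded in train().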
+ this.topActivationsForUserInputFeature = indexedActivations
+ .slice(0, nTop).map((item: any) => {
+ const trainingInput = this.trainingInputs[item.index];
+ return {
+ 'value': item.value,
+ ...trainingInput
+ };
+ });
+ }
+
async train() {
tf.util.shuffle(this.trainingData);
- const nTrainingData = this.trainingData.length;
- const trueActivations = tf.concat(this.trainingData
+ // For each token in each sequence, record metadata (the token itself, its position, and the full sequence) so feature activations can later be traced back to their context.
+ this.trainingInputs = this.trainingData.map((item: any) =>
+ item.input.map((d: any, i: number) => ({
+ 'token': d,
+ 'sequence': item.input,
+ 'tokenPos': i
+ })))
+ .reduce((acc: any, curr: any) => acc.concat(curr), []); // flatten.
+
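+ // Stack every sequence's MLP activations into one [nTokens, MLP_ACT_SIZE] matrix,
+ // whose rows line up with the flattened trainingInputs above.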
+ this.trainingData = tf.concat(this.trainingData
.map((item: any) => tf.tensor(item.mlpOutputs.data, item.mlpOutputs.shape).squeeze()));
-
- const mlpActSize = 8;
- const dictionaryMultiplier = 4;
- const dHidden = mlpActSize * dictionaryMultiplier; // learned feature size
- const l1Coeff = 0.0003;
+ const nTrainingData = this.trainingData.shape[0];
const inputs = tf.input({
- shape: [mlpActSize],
+ shape: [MLP_ACT_SIZE],
name: 'sae_input'
});
// const inputBias = tf.input({
- // shape: [mlpActSize],
+ // shape: [MLP_ACT_SIZE],
// name: 'sae_input_bias'
// });
// const biasedInput = tf.layers.add().apply([inputs, inputBias]);
- const activations = tf.layers.dense({
- units: dHidden,
+ const dictionaryFeatures = tf.layers.dense({
+ units: D_HIDDEN,
useBias: true,
activation: 'relu',
}).apply(inputs) as any;
const reconstruction = tf.layers.dense({
- units: mlpActSize,
+ units: MLP_ACT_SIZE,
useBias: true,
- }).apply(activations) as any;
+ }).apply(dictionaryFeatures) as any;
- // Adding a layer to concatenate activations to the reconstruction as final output so both are available in the loss function as yPred, because intermediate activations are needed to compute L1 loss term.
+ // Concatenate dictionaryFeatures onto the reconstruction as the final output so both reach the loss function as yPred; the intermediate dictionaryFeatures are needed there to compute the L1 sparsity term.
// Alternatives tried:
// - Retrieving intermediate output in the loss function - couldn't figure out how to retrieve as a non-symbolic tensor
// - Outputting multiple tensors in the model - but yPred in the loss function is still only the first output tensor
- const combinedOutput = tf.layers.concatenate({axis: 1}).apply([activations, reconstruction]) as any;
- const saeModel = tf.model({inputs: [inputs], outputs: [combinedOutput]});
+ const combinedOutput = tf.layers.concatenate({axis: 1}).apply([dictionaryFeatures, reconstruction]) as any;
+ this.saeModel = tf.model({inputs: [inputs], outputs: [combinedOutput]});
- saeModel.compile({
+ this.saeModel.compile({
optimizer: tf.train.adam(),
loss: (yTrue: tf.Tensor, yPred: tf.Tensor) => {
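+ // SAE loss: L2 reconstruction error plus an L1 penalty on the dictionary-feature
+ // activations to encourage sparsity.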
- const outputActivations = yPred.slice([0, 0], [-1, dHidden]);
- const outputReconstruction = yPred.slice([0, dHidden], [-1, -1]);
- const trueReconstruction = yTrue.slice([0, dHidden], [-1, -1]);
+ const outputDictionaryFeatures = yPred.slice([0, 0], [-1, D_HIDDEN]);
+ const outputReconstruction = yPred.slice([0, D_HIDDEN], [-1, -1]);
+ const trueReconstruction = yTrue.slice([0, D_HIDDEN], [-1, -1]);
const l2Loss = tf.losses.meanSquaredError(trueReconstruction, outputReconstruction);
- const l1Loss = tf.mul(l1Coeff, tf.sum(tf.abs(outputActivations)));
+ const l1Loss = tf.mul(L1_COEFF, tf.sum(tf.abs(outputDictionaryFeatures)));
return tf.add(l2Loss, l1Loss);
},
});
const epochSize = 8;
// This tensor is unused - it's just to make yTrue shape match the concatenated output.
- const placeholderActivationsTensor = tf.randomNormal([epochSize, dHidden]);
+ const placeholderDictionaryFeatures = tf.randomNormal([epochSize, D_HIDDEN]);
for (let i=0; i