diff --git a/animated-transformer/src/app/sae/sae.component.html b/animated-transformer/src/app/sae/sae.component.html
index d9b04df..5c7ddf2 100644
--- a/animated-transformer/src/app/sae/sae.component.html
+++ b/animated-transformer/src/app/sae/sae.component.html
@@ -14,4 +14,34 @@
   <div class="status">
     {{status}}
   </div>
 }
+
+@if (trained) {
+  <div class="results">
+    @if (learnedFeatureActivationFrequencies.length) {
+      <div class="frequencies">
+        <div>Frequencies (avg = {{averageLearnedFeatureActivationFrequency}}):</div>
+        <div>
+          @for (item of learnedFeatureActivationFrequencies; track $index; let i = $index) {
+            <div>{{i}}: {{item}}</div>
+          }
+        </div>
+      </div>
+    }
+    <div class="interpreter">
+      <div>Interpret a feature</div>
+      <div>
+        <input type="number" [(ngModel)]="userInput" />
+        <button mat-button (click)="interpret()">Interpret</button>
+      </div>
+      @if (topActivationsForUserInputFeature) {
+        <div>
+          <div>Top activating data for neuron {{userInput}}:</div>
+          @for (item of topActivationsForUserInputFeature; track $index; let i = $index) {
+            <div>{{i}}: {{item.value | number:'1.2-2'}} {{item.token}} (Pos {{item.tokenPos}}: {{item.sequence}})</div>
+          }
+        </div>
+      }
+    </div>
+  </div>
+}
\ No newline at end of file
diff --git a/animated-transformer/src/app/sae/sae.component.scss b/animated-transformer/src/app/sae/sae.component.scss
index 3ffdc37..ecf3828 100644
--- a/animated-transformer/src/app/sae/sae.component.scss
+++ b/animated-transformer/src/app/sae/sae.component.scss
@@ -17,4 +17,13 @@
 
 .trainer .status {
   margin-top: 20px;
+}
+
+.interpreter, .frequencies {
+  margin-top: 20px;
+}
+
+.results {
+  display: flex;
+  gap: 20px;
 }
\ No newline at end of file
diff --git a/animated-transformer/src/app/sae/sae.component.ts b/animated-transformer/src/app/sae/sae.component.ts
index aa12820..020433c 100644
--- a/animated-transformer/src/app/sae/sae.component.ts
+++ b/animated-transformer/src/app/sae/sae.component.ts
@@ -1,4 +1,4 @@
-/* Copyright 2023 Google LLC. All Rights Reserved.
+/* Copyright 2024 Google LLC. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@ limitations under the License.
 
 import { AfterViewInit, Component, OnInit } from '@angular/core';
 import * as tf from '@tensorflow/tfjs';
+import { CommonModule } from '@angular/common';
+import { FormsModule } from '@angular/forms';
 import { computeTransformer, initDecoderParams } from '../../lib/transformer/transformer_gtensor';
 import * as gtensor from '../../lib/gtensor/gtensor';
 import { gtensorTrees } from '../../lib/gtensor/gtensor_tree';
@@ -29,14 +31,29 @@ import { BasicLmTask, BasicLmTaskUpdate } from 'src/lib/seqtasks/util';
 import { MatButtonModule } from '@angular/material/button';
 
+const MLP_ACT_SIZE = 8;
+const DICTIONARY_MULTIPLIER = 4;
+const D_HIDDEN = MLP_ACT_SIZE * DICTIONARY_MULTIPLIER; // learned feature size
+const L1_COEFF = 0.003;
+
 @Component({
   selector: 'app-sae',
+  standalone: true,
   templateUrl: './sae.component.html',
-  styleUrls: ['./sae.component.scss']
+  styleUrls: ['./sae.component.scss'],
+  imports: [CommonModule, MatButtonModule, FormsModule],
 })
 export class SAEComponent {
   status: string = '';
+  public saeModel: any;
   public trainingData: any;
+  public trainingInputs: any;
+  public trained = false;
+  learnedFeatureActivationFrequencies: number[] = [];
+  averageLearnedFeatureActivationFrequency: number = 0;
+  predictedDictionaryFeatures: any;
+  topActivationsForUserInputFeature: any;
+  userInput: any;
 
   constructor(
     private route: ActivatedRoute,
     private router: Router,
@@ -59,74 +76,117 @@ export class SAEComponent {
     reader.readAsText(file);
   }
 
+  async interpret() {
+    const activationsForFeatureToInspect = Array.from(
+      this.predictedDictionaryFeatures.slice([0, this.userInput], [-1, 1]).dataSync());
+    const indexedActivations = activationsForFeatureToInspect.map((value, index) => ({ value, index }));
+    indexedActivations.sort((a: any, b: any) => {
+      if (a.value < b.value) {
+        return 1;
+      }
+      return -1;
+    });
+
+    const nTop = 50;
+    this.topActivationsForUserInputFeature = indexedActivations
+      .slice(0, nTop).map((item: any) => {
+        const trainingInput = this.trainingInputs[item.index];
+        return {
+          'value': item.value,
+          ...trainingInput
+        };
+      });
+  }
+
   async train() {
     tf.util.shuffle(this.trainingData);
-    const nTrainingData = this.trainingData.length;
-    const trueActivations = tf.concat(this.trainingData
+    // For each token of each training sequence, record its metadata: the token itself, its position in the sequence, and the sequence it came from.
+    this.trainingInputs = this.trainingData.map((item: any) =>
+      item.input.map((d: any, i: number) => ({
+        'token': d,
+        'sequence': item.input,
+        'tokenPos': i
+      })))
+      .reduce((acc: any, curr: any) => acc.concat(curr), []); // flatten.
+
+    this.trainingData = tf.concat(this.trainingData
       .map((item: any) => tf.tensor(item.mlpOutputs.data, item.mlpOutputs.shape).squeeze()));
-
-    const mlpActSize = 8;
-    const dictionaryMultiplier = 4;
-    const dHidden = mlpActSize * dictionaryMultiplier; // learned feature size
-    const l1Coeff = 0.0003;
+    const nTrainingData = this.trainingData.shape[0];
 
     const inputs = tf.input({
-      shape: [mlpActSize],
+      shape: [MLP_ACT_SIZE],
       name: 'sae_input'
     });
     // const inputBias = tf.input({
-    //   shape: [mlpActSize],
+    //   shape: [MLP_ACT_SIZE],
    //   name: 'sae_input_bias'
    // });
    // const biasedInput = tf.layers.add().apply([inputs, inputBias]);
 
-    const activations = tf.layers.dense({
-      units: dHidden,
+    const dictionaryFeatures = tf.layers.dense({
+      units: D_HIDDEN,
       useBias: true,
       activation: 'relu',
     }).apply(inputs) as any;
     const reconstruction = tf.layers.dense({
-      units: mlpActSize,
+      units: MLP_ACT_SIZE,
       useBias: true,
-    }).apply(activations) as any;
+    }).apply(dictionaryFeatures) as any;
 
-    // Adding a layer to concatenate activations to the reconstruction as final output so both are available in the loss function as yPred, because intermediate activations are needed to compute L1 loss term.
+    // Concatenate dictionaryFeatures onto the reconstruction as the final output so both are available to the loss function as yPred; the intermediate dictionaryFeatures are needed to compute the L1 loss term.
     // Alternatives tried:
     // - Retrieving intermediate output in the loss function - couldn't figure out how to retrieve as a non-symbolic tensor
     // - Outputting multiple tensors in the model - but yPred in the loss function is still only the first output tensor
-    const combinedOutput = tf.layers.concatenate({axis: 1}).apply([activations, reconstruction]) as any;
-    const saeModel = tf.model({inputs: [inputs], outputs: [combinedOutput]});
+    const combinedOutput = tf.layers.concatenate({axis: 1}).apply([dictionaryFeatures, reconstruction]) as any;
+    this.saeModel = tf.model({inputs: [inputs], outputs: [combinedOutput]});
 
-    saeModel.compile({
+    this.saeModel.compile({
       optimizer: tf.train.adam(),
       loss: (yTrue: tf.Tensor, yPred: tf.Tensor) => {
-        const outputActivations = yPred.slice([0, 0], [-1, dHidden]);
-        const outputReconstruction = yPred.slice([0, dHidden], [-1, -1]);
-        const trueReconstruction = yTrue.slice([0, dHidden], [-1, -1]);
+        const outputDictionaryFeatures = yPred.slice([0, 0], [-1, D_HIDDEN]);
+        const outputReconstruction = yPred.slice([0, D_HIDDEN], [-1, -1]);
+        const trueReconstruction = yTrue.slice([0, D_HIDDEN], [-1, -1]);
         const l2Loss = tf.losses.meanSquaredError(trueReconstruction, outputReconstruction);
-        const l1Loss = tf.mul(l1Coeff, tf.sum(tf.abs(outputActivations)));
+        const l1Loss = tf.mul(L1_COEFF, tf.sum(tf.abs(outputDictionaryFeatures)));
         return tf.add(l2Loss, l1Loss);
       },
     });
 
     const epochSize = 8;
     // This tensor is unused - it's just to make yTrue shape match the concatenated output.
-    const placeholderActivationsTensor = tf.randomNormal([epochSize, dHidden]);
+    const placeholderDictionaryFeatures = tf.randomNormal([epochSize, D_HIDDEN]);
     for (let i=0; i
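
Note on the loss workaround above: a TF.js custom loss only receives `yTrue` and `yPred`, so the model concatenates the dictionary features onto the reconstruction to make the L1 sparsity term computable. A minimal standalone sketch of the same pattern; sizes and names here are illustrative, not taken from this PR:

```ts
import * as tf from '@tensorflow/tfjs';

const MLP_ACT_SIZE = 8; // width of the MLP activations being reconstructed
const D_HIDDEN = 32;    // overcomplete dictionary size
const L1_COEFF = 0.003; // sparsity penalty weight

const inputs = tf.input({ shape: [MLP_ACT_SIZE] });
// Encoder: ReLU dictionary features.
const features = tf.layers
  .dense({ units: D_HIDDEN, useBias: true, activation: 'relu' })
  .apply(inputs) as tf.SymbolicTensor;
// Decoder: reconstruct the original activations.
const reconstruction = tf.layers
  .dense({ units: MLP_ACT_SIZE, useBias: true })
  .apply(features) as tf.SymbolicTensor;
// Concatenate [features | reconstruction] so the loss sees both via yPred.
const combined = tf.layers
  .concatenate({ axis: 1 })
  .apply([features, reconstruction]) as tf.SymbolicTensor;
const model = tf.model({ inputs, outputs: combined });

model.compile({
  optimizer: tf.train.adam(),
  loss: (yTrue, yPred) => {
    const feats = yPred.slice([0, 0], [-1, D_HIDDEN]);        // first D_HIDDEN columns
    const recon = yPred.slice([0, D_HIDDEN], [-1, -1]);       // remaining columns
    const target = yTrue.slice([0, D_HIDDEN], [-1, -1]);      // true activations
    const l2Loss = tf.losses.meanSquaredError(target, recon); // reconstruction error
    const l1Loss = tf.mul(L1_COEFF, tf.sum(tf.abs(feats)));   // sparsity penalty
    return tf.add(l2Loss, l1Loss);
  },
});
```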
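The TS hunk is cut off inside the training loop. A hedged sketch of how the continuation of `train()` plausibly proceeds, assuming each `fit` call consumes one `epochSize`-row slice and pads `yTrue` with the placeholder tensor; none of this is confirmed by the diff:

```ts
// Hypothetical continuation of train(); variable names reuse the method scope above.
for (let i = 0; i < nTrainingData / epochSize; i++) {
  const batch = this.trainingData.slice([i * epochSize, 0], [epochSize, -1]);
  // yTrue = [placeholder | true activations]; the placeholder half is ignored
  // by the loss, which slices out only the reconstruction target.
  const yTrue = tf.concat([placeholderDictionaryFeatures, batch], 1);
  await this.saeModel.fit(batch, yTrue, { epochs: 1, batchSize: epochSize });
}
// The truncated remainder presumably also computes predictedDictionaryFeatures,
// the activation frequencies shown in the template, and sets this.trained = true.
```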