
Temporary fix for rms norm backward on CPU. #1197


Open
Alkamist wants to merge 1 commit into master

Conversation

Alkamist

From my testing, ggml_compute_forward_rms_norm_back_f32 does not function properly on the CPU backend.

I originally noticed it because CUDA seemed to be working for what I was doing, while the CPU version was completely unstable.

I made a very quick and dirty fix that seems to work in my testing, but it probably shouldn't be merged as is. I mainly wanted to call attention to ggml_compute_forward_rms_norm_back_f32 not working properly and to provide a temporary fix that works for what I'm doing at the moment.
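
For reference, this is the gradient I believe both versions are meant to compute (my own notation and derivation, so take it with a grain of salt): with row length $n$, upstream gradient $g$, eps $\epsilon$, and loss $L$,

$$
y_i = \frac{x_i}{r}, \qquad r = \sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}, \qquad
\frac{\partial L}{\partial x_i} = \frac{1}{r}\left(g_i - \frac{x_i \sum_{j=1}^{n} g_j x_j}{n\,r^2}\right)
$$

The master code and this PR group the intermediate terms differently, so if they disagree in practice it presumably comes down to floating point behavior rather than the math itself.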

@JohannesGaessler self-requested a review on April 23, 2025, 07:37
@JohannesGaessler
Collaborator

In an abstract mathematical sense, the code on master and with this PR should be equivalent. However, there could be numerical issues. In what context did you encounter this problem? Did you investigate the values that the changed variables can assume? How did you discover that a change to RMS norm fixes the issue?

@Alkamist
Author

Alkamist commented Apr 23, 2025

> In an abstract mathematical sense, the code on master and with this PR should be equivalent. However, there could be numerical issues. In what context did you encounter this problem? Did you investigate the values that the changed variables can assume? How did you discover that a change to RMS norm fixes the issue?

I'm using a custom binding and wrapper I made in the Odin programming language, so it may be hard to provide a concrete example that translates 1 to 1.

It's entirely possible that the problem is some subtle interaction with my wrapper, but I think it is unlikely.

The context is that I was experimenting with training a basic transformer architecture on text. When I used CUDA, the model would learn, but when I used the CPU, the loss would fluctuate wildly and eventually rise.

I then started investigating my transformer implementation, and by commenting things out and observing how that affected the loss, I concluded that the problem was coming from RMS norm.

I just assumed the problem was in the backward pass, so I compared the CPU backward function to the one in the CUDA code, and through trial and error came up with what is in this pull request, which for whatever reason fixes the issue for me.

I am currently using a gallocr and allocating on a single device with the backend API, but using a backend_sched instead didn't seem to matter; I got the same issues.
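
In case the allocation details matter, the setup is roughly equivalent to this minimal C sketch of the ggml calls (the includes, the 64-element tensor, and the memory sizes are placeholders for illustration, not my actual model, which goes through the Odin wrapper):

// Minimal single-backend setup: one CPU backend, graph tensors allocated by a gallocr.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"

int main(void) {
	struct ggml_init_params params = {
		/*.mem_size   =*/ ggml_tensor_overhead() * 256 + ggml_graph_overhead(),
		/*.mem_buffer =*/ NULL,
		/*.no_alloc   =*/ true, // tensor data is allocated by the gallocr below
	};
	struct ggml_context * ctx = ggml_init(params);

	ggml_backend_t backend = ggml_backend_cpu_init();
	ggml_gallocr_t galloc  = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));

	// build a tiny graph: just an RMS norm over one row
	struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
	struct ggml_tensor * y = ggml_rms_norm(ctx, x, 1e-5f);

	struct ggml_cgraph * gf = ggml_new_graph(ctx);
	ggml_build_forward_expand(gf, y);

	ggml_gallocr_alloc_graph(galloc, gf);    // allocate every graph tensor on this one backend
	// (real code would upload input data here with ggml_backend_tensor_set)
	ggml_backend_graph_compute(backend, gf); // run the forward pass

	ggml_gallocr_free(galloc);
	ggml_backend_free(backend);
	ggml_free(ctx);
	return 0;
}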

I'll post my basic architecture for reference, but again, it's a wrapper around GGML so things aren't completely 1 to 1, and it's also in a different language, so I'm not sure how helpful it will be.

attention :: proc(graph: ^ml.Graph, q, k, v, kq_mask, positions: ^ml.Tensor, kq_scale: f32) -> (res: ^ml.Tensor) {
	// split the embedding dimension into attention heads
	q := ml.reshape(graph, q, EMBEDDING_DIMENSIONS / ATTENTION_HEADS, ATTENTION_HEADS, MAX_TOKENS)
	k := ml.reshape(graph, k, EMBEDDING_DIMENSIONS / ATTENTION_HEADS, ATTENTION_HEADS, MAX_TOKENS)
	v := ml.reshape(graph, v, EMBEDDING_DIMENSIONS / ATTENTION_HEADS, ATTENTION_HEADS, MAX_TOKENS)

	// rotary position embeddings on the queries and keys
	q = ml.rope(graph, q, positions, EMBEDDING_DIMENSIONS / ATTENTION_HEADS, .Normal)
	k = ml.rope(graph, k, positions, EMBEDDING_DIMENSIONS / ATTENTION_HEADS, .Normal)

	ml.forward(graph, q)
	ml.forward(graph, k)
	ml.forward(graph, v)

	q = ml.permute(graph, q, 0, 2, 1, 3)
	k = ml.permute(graph, k, 0, 2, 1, 3)
	v = ml.permute(graph, v, 0, 2, 1, 3)

	// per-head attention scores, then masked and scaled softmax
	kq := ml.matmul(graph, k, q)
	kq  = ml.softmax(graph, kq, kq_mask, kq_scale)

	v = ml.contiguous(graph, ml.transpose(graph, v))

	// attention-weighted sum of the values, merged back to [EMBEDDING_DIMENSIONS, MAX_TOKENS]
	kqv := ml.matmul(graph, v, kq)
	kqv  = ml.permute(graph, kqv, 0, 2, 1, 3)

	res = ml.contiguous(graph, kqv, EMBEDDING_DIMENSIONS, MAX_TOKENS)

	ml.forward(graph, res)

	return
}
rms_layernorm :: proc(graph: ^ml.Graph, t, weight: ^ml.Tensor, epsilon: f32 = 1e-5) -> (res: ^ml.Tensor) {
	res = ml.rms_norm(graph, t, epsilon)
	res = ml.mul(graph, res, weight)
	return
}
swiglu :: proc(graph: ^ml.Graph, t, w0, w1, w2: ^ml.Tensor) -> ^ml.Tensor {
	x0     := ml.matmul(graph, w0, t)
	x1     := ml.matmul(graph, w1, t)
	hidden := ml.mul(graph, ml.silu(graph, x0), x1)
	return ml.matmul(graph, w2, hidden)
}
forward :: proc(model: ^Model, graph: ^ml.Graph, tokens, positions: ^ml.Tensor) -> ^ml.Tensor {
	// look up the token embeddings
	t := ml.get_rows(graph, model.token_embeddings, tokens)

	residual := t

	kq_scale := 1.0 / math.sqrt(f32(EMBEDDING_DIMENSIONS) / ATTENTION_HEADS)

	for layer, i in model.layers {
		// pre-norm attention block with residual connection
		t = rms_layernorm(graph, residual, layer.norm0_weight)

		q := ml.matmul(graph, layer.q_weight, t)
		k := ml.matmul(graph, layer.k_weight, t)
		v := ml.matmul(graph, layer.v_weight, t)
		t  = attention(graph, q, k, v, model.attention_mask, positions, kq_scale)
		t  = ml.matmul(graph, layer.proj_weight, t)

		residual = ml.add(graph, residual, t)

		// pre-norm SwiGLU feed-forward block with residual connection
		t = rms_layernorm(graph, residual, layer.norm1_weight)

		t = swiglu(graph, t, layer.swiglu0_weight, layer.swiglu1_weight, layer.swiglu2_weight)

		residual = ml.add(graph, residual, t)
	}

	// final norm; the output projection reuses the token embedding matrix
	t = rms_layernorm(graph, residual, model.norm_weight)

	t = ml.matmul(graph, model.token_embeddings, t)

	ml.set_output(t)
	ml.forward(graph, t)

	return t
}
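
For anyone who doesn't read Odin, the rms_layernorm helper above corresponds roughly to the following plain ggml calls (a sketch of what the wrapper ends up building, assuming the standard ggml_rms_norm and ggml_mul signatures):

// RMS-normalize the activations, then scale elementwise by the learned weight.
static struct ggml_tensor * rms_layernorm(
		struct ggml_context * ctx,
		struct ggml_tensor  * t,
		struct ggml_tensor  * weight,
		float                 eps) {
	struct ggml_tensor * res = ggml_rms_norm(ctx, t, eps);
	return ggml_mul(ctx, res, weight);
}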

@JohannesGaessler
Collaborator

Alright, thank you for the information. If I had to guess, this problem is caused by numerical issues. Unfortunately I don't have a good understanding of which implementation would in general have fewer issues beyond your specific use case. Progress on language model training in llama.cpp is currently stalled; I would keep this issue in mind and revisit it in that context once I have a concrete use case with which I can investigate and debug it.
