It comes down to tensor shape. 2D = good, 3D = bad. Relevant commit: https://github.com/shawwn/gpt-2/commit/4d766e9629f28732df615e1dd4e2d3174f3cf703
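
For anyone who doesn't want to dig through the commit, here is a rough sketch of what "keep it 2D" can look like in general. This is not a copy of the linked commit and may not match what it actually changes; it only illustrates the usual trick of collapsing the leading dimensions of a 3D activation into one, doing a plain 2D matmul, and reshaping back. The function name `dense_2d`, the `[batch, seq, features]` layout, and the use of NumPy instead of the repo's framework are all assumptions made for illustration.

```python
import numpy as np


def dense_2d(x, w, b):
    """Apply a dense layer to a [batch, seq, n_in] array using a 2D matmul.

    Hypothetical helper for illustration only; not taken from the linked commit.
    """
    batch, seq, n_in = x.shape
    n_out = w.shape[1]
    # Collapse batch and sequence into one axis so the matmul operands are 2D.
    x2d = x.reshape(batch * seq, n_in)
    # Plain [batch*seq, n_in] @ [n_in, n_out] product plus a broadcast bias.
    y2d = x2d @ w + b
    # Restore the original leading dimensions for the caller.
    return y2d.reshape(batch, seq, n_out)


# Small example with made-up shapes.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
w = rng.standard_normal((4, 5))
b = rng.standard_normal(5)
y = dense_2d(x, w, b)
```

The result is numerically the same as the 3D contraction (`np.einsum('bsi,io->bso', x, w) + b`); only the shape the matmul sees is different.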