My value_and_grad is taking longer than expected (I'm sure I'm doing something wrong)
#17183
Hi! I'm currently trying to refactor this Google Research repo for my own project. As a first goal, I want to reproduce the results on CIFAR-10 with the ResNet FRN model they use. However, I feel like I'm not fully utilizing the single GPU I have. I've attached a (lengthy) bit of code demonstrating what I want to do, and how each call to value_and_grad seems to take longer than expected. My question is: is there anything I'm not doing to fully utilize the GPU and maximize training speed? I've also attached timing results comparing the CPU and the single GPU (a Tesla V100). Thank you very much in advance! I know this is quite an involved question, so I appreciate it.

import jax
import jax.numpy as jnp
import haiku as hk
import time
import math
from dataclasses import dataclass
from typing import Callable, Dict, Tuple
from jax.random import PRNGKey
import tensorflow_datasets as tfds
TFLoader = tfds.core.dataset_utils._IterableDataset
# jax.config.update('jax_platform_name', 'cpu')
# // Model ----------------------------------------
he_normal = hk.initializers.VarianceScaling(2.0, "fan_in", "truncated_normal")
class FeatureResponseNorm(hk.Module):
def __init__(self, eps=1e-6, name="frn"):
super().__init__(name=name)
self.eps = eps
def __call__(self, x, **unused_kwargs):
del unused_kwargs
par_shape = (1, 1, 1, x.shape[-1]) # [1,1,1,C]
tau = hk.get_parameter("tau", par_shape, x.dtype, init=jnp.zeros)
beta = hk.get_parameter("beta", par_shape, x.dtype, init=jnp.zeros)
gamma = hk.get_parameter("gamma", par_shape, x.dtype, init=jnp.ones)
nu2 = jnp.mean(jnp.square(x), axis=[1, 2], keepdims=True)
x = x * jax.lax.rsqrt(nu2 + self.eps)
y = gamma * x + beta
z = jnp.maximum(y, tau)
return z
def _resnet_layer(
inputs,
num_filters,
normalization_layer,
kernel_size=3,
strides=1,
activation=lambda x: x,
use_bias=True,
is_training=True,
):
x = inputs
x = hk.Conv2D(
num_filters,
kernel_size,
stride=strides,
padding="same",
w_init=he_normal,
with_bias=use_bias,
)(x)
x = normalization_layer()(x, is_training=is_training)
x = activation(x)
return x
def make_resnet_fn(
num_classes,
depth,
normalization_layer,
width=16,
use_bias=True,
activation=jax.nn.relu,
):
num_res_blocks = (depth - 2) // 6
if (depth - 2) % 6 != 0:
raise ValueError("depth must be 6n+2 (e.g. 20, 32, 44).")
def forward(x, is_training):
num_filters = width
#x, _ = batch
x = _resnet_layer(
x,
num_filters=num_filters,
activation=activation,
use_bias=use_bias,
normalization_layer=normalization_layer,
)
for stack in range(3):
for res_block in range(num_res_blocks):
strides = 1
if stack > 0 and res_block == 0: # first layer but not first stack
strides = 2 # downsample
y = _resnet_layer(
x,
num_filters=num_filters,
strides=strides,
activation=activation,
use_bias=use_bias,
is_training=is_training,
normalization_layer=normalization_layer,
)
y = _resnet_layer(
y,
num_filters=num_filters,
use_bias=use_bias,
is_training=is_training,
normalization_layer=normalization_layer,
)
if stack > 0 and res_block == 0: # first layer but not first stack
# linear projection residual shortcut connection to match changed dims
x = _resnet_layer(
x,
num_filters=num_filters,
kernel_size=1,
strides=strides,
use_bias=use_bias,
is_training=is_training,
normalization_layer=normalization_layer,
)
x = activation(x + y)
num_filters *= 2
x = hk.AvgPool((8, 8, 1), 8, "VALID")(x)
x = hk.Flatten()(x)
logits = hk.Linear(num_classes, w_init=he_normal)(x)
return logits
return forward
def make_resnet20_frn_fn(data_info, activation=jax.nn.relu):
num_classes = data_info["num_classes"]
net = hk.transform_with_state(make_resnet_fn(
num_classes,
depth=20,
normalization_layer=FeatureResponseNorm,
activation=activation,
))
return net.apply, net.init
# // Core Functions -------------------------------------
def tree_dot(a, b):
return sum(
[
jnp.sum(e1 * e2)
for e1, e2 in zip(
jax.tree_util.tree_leaves(a), jax.tree_util.tree_leaves(b)
)
]
)
def prediction_fn(
net_apply: Callable,
net_state: Dict,
params: Dict,
x: jax.Array,
is_training: bool,
rng: PRNGKey = None,
as_logits: bool = True,
) -> Tuple[jax.Array, Dict]:
"""
Make a prediction with a Haiku model (Transformed class).
Parameters
----------
net_apply : Callable
the Haiku model's apply function
net_state : Dict
the Haiku model's current state
params : Dict
parameter dictionary
x : jax.Array
the inputs
is_training : bool
whether training or not
as_logits : bool, optional
whether to return as logits or softmax outputs, by default True
rng : PRNGKey, optional
rng for prediction, by default None
Returns
-------
Tuple[jax.Array, Dict]
logits/softmax outputs and the net state as a result
of applying params to x.
"""
logits, net_state = net_apply(
params=params,
state=net_state,
rng=rng,
x=x,
is_training=is_training,
)
if as_logits:
return logits, net_state
else:
return jax.nn.softmax(logits), net_state
# // Likelihood --------------------------
@dataclass
class CategoricalLogLikelihood:
"""
Log likelihood function for classification.
Attributes
----------
temperature : float
Temperature: a constant by which the log likelihood value is divided.
Methods
-------
value(net_apply, net_state, params, x, y, ...) -> Tuple[float, Dict]
Compute the value of the log likelihood function for a given setting of the
parameters and data x,y. Also returns the net_state due to predictions
grad() -> Callable[..., Tuple[Dict, Dict]]
Return a jax.grad function of the log likelihood function with respect to
the parameters.
value_and_grad() -> Callable[..., Tuple[Tuple[jax.Array, Dict], Dict]]
Return a jax.value_and_grad function of the log likelihood function with
respect to the parameters.
"""
temperature: float
def value(
self,
net_apply: Callable,
net_state: Dict,
params: Dict,
x: jax.Array,
y: jax.Array,
is_training: bool = True,
) -> Tuple[float, Dict]:
"""
Use the net_apply function, the net_state, and parameters to make predictions on
inputs `x`. Then compute the log likelihood with respect to the data `y`.
Parameters
----------
net_apply : Callable
the Haiku model's apply function
net_state : Dict
the Haiku model's current state
params : Dict
parameter dictionary
x : jax.Array
the inputs
y : jax.Array
the labels
is_training : bool, optional
whether training or not, by default True
Returns
-------
Tuple[float, Dict]
The log likelihood evaluated at the parameters under the data, and
the net_state as a result of its apply.
"""
logits, net_state = prediction_fn(
net_apply=net_apply,
net_state=net_state,
params=params,
x=x,
is_training=is_training,
rng=None,
as_logits=True,
)
# Convert labels to one hot
num_classes = logits.shape[-1]
labels_one_hot = jax.nn.one_hot(y, num_classes)
# Compute log likelihood and divide by temp
out = jnp.sum(labels_one_hot * jax.nn.log_softmax(logits)) / self.temperature
return out, net_state
def grad(self) -> Callable[..., Tuple[Dict, Dict]]:
"""
Return a function that computes the gradient of the log likelihood function
with respect to the parameters.
Returns
-------
Callable[..., Tuple[Dict, Dict]]
A function that computes the gradient with respect to the parameters, and the net_state
as a result of the apply in `value`
Usage
-----
likelihood = CategoricalLogLikelihood(temperature=1.0)
(grad, net_state) = likelihood.grad()(
net_apply,
net_state,
params,
x,
y,
is_training,
)
"""
# (params is arg 2 in `value`) -> gradient taken with respect to params.
return jax.grad(self.value, argnums=2, has_aux=True)
def value_and_grad(self) -> Callable[..., Tuple[Tuple[jax.Array, Dict], Dict]]:
"""
Return a function that computes the value and gradient with respect to the
parameters. The returned function itself returns the log likelihood value at
(params, x, y) as well as the gradient with respect to the parameters evaluated
at `params`.
Returns
-------
Callable[..., Tuple[Tuple[jax.Array, Dict], Dict]]
A function that computes the value of the function, the gradient, and the
net_state as a result of the apply in `value`. The output of the function
has the form:
(
(
the log likelihood evaluated at (params, x, y),
the net_state as a result of the apply in `value`
),
the gradient with respect to params
)
Usage
-----
likelihood = CategoricalLogLikelihood(temperature=1.0)
(likelihood_val, net_state), grad = likelihood.value_and_grad()(
net_apply,
net_state,
params,
x,
y,
is_training
)
"""
# (params is arg 2 in `value`) -> gradient taken with respect to params.
return jax.value_and_grad(self.value, argnums=2, has_aux=True)
@dataclass
class GaussianLogPrior:
"""
Log prior function for a zero-mean Gaussian prior
Attributes
----------
temperature : float
Temperature: a constant by which the log prior value is divided.
weight_decay : float
weight decay value. Corresponds to the precision (1/variance) of the prior.
"""
temperature: float
weight_decay: float
def value(self, params: Dict) -> float:
"""
Evaluate the log prior at `params`. Note that self.weight_decay
acts as the precision (inverse variance) of the prior.
Parameters
----------
params : Dict
a dictionary of parameters
Returns
-------
float
the log prior evaluated at `params`, divided by self.temperature
"""
n_params = sum([p.size for p in jax.tree_util.tree_leaves(params)])
log_prob = -(
0.5 * tree_dot(params, params) * self.weight_decay
+ 0.5 * n_params * jnp.log((2 * math.pi) / self.weight_decay)
)
return log_prob / self.temperature
def grad(self) -> Callable[..., Dict]:
"""
Return a function that computes the gradient of the log prior with respect to
the parameters.
Returns
-------
Callable[..., Dict]
A function that computes the gradient of the log prior with respect to the
parameters, evaluated at a given `params` value.
"""
return jax.grad(self.value, argnums=0)
def value_and_grad(self) -> Callable[..., Tuple[float, Dict]]:
"""
Return a function that computes the value and the gradient of the log prior
with respect to the parameters.
Returns
-------
Callable[..., Tuple[float, Dict]]
A function that computes the value of the log prior and its gradient with
respect to the parameters, evaluated at a given `params` value.
"""
return jax.value_and_grad(self.value, argnums=0)
@dataclass
class LogProbability:
likelihood: CategoricalLogLikelihood
prior: GaussianLogPrior
def value(
self,
net_apply: Callable,
net_state: Dict,
params: Dict,
x: jax.Array,
y: jax.Array,
is_training: bool = False,
):
likelihood_value, net_state = self.likelihood.value(
net_apply, net_state, params, x=x, y=y, is_training=is_training
)
prior_value = self.prior.value(params)
log_prob_value = likelihood_value + prior_value
return log_prob_value, net_state
def grad(self):
return jax.grad(self.value, argnums=2, has_aux=True)
def value_and_grad(self):
return jax.value_and_grad(self.value, argnums=2, has_aux=True)
# // Main -----------------------------------
key = jax.random.PRNGKey(123)
net_apply, net_init = make_resnet20_frn_fn(data_info={"num_classes": 10})
key, subkey = jax.random.split(key)
init_data = jax.random.normal(subkey, [1, 32, 32, 3]) #, jnp.array([0.])
# // Get initial parameters and state of network
key, subkey = jax.random.split(key)
params, net_state = net_init(subkey, init_data, True)
## // Setup log probability
likelihood_object = CategoricalLogLikelihood(temperature=1.0)
prior_object = GaussianLogPrior(temperature=1.0, weight_decay=5.0)
log_probability = LogProbability(likelihood_object, prior_object)
for i in range(30):
key, subkey = jax.random.split(key)
# Create random features
x = jax.random.normal(subkey, [100, 32, 32, 3])
y = jnp.ones(shape=(100,), dtype=int)
# TIME OP
start = time.time()
(logprob_value, net_state), logprob_grad = log_probability.value_and_grad()(
net_apply,
net_state,
params,
x,
y,
True
)
end = time.time()
print(f"Time elapsed for call {i}: {end-start} seconds") WITH CPU: Time elapsed for call 0: 6.852431058883667 seconds
Time elapsed for call 1: 2.4112114906311035 seconds
Time elapsed for call 2: 1.8837454319000244 seconds
Time elapsed for call 3: 1.9587490558624268 seconds
Time elapsed for call 4: 1.711151123046875 seconds
Time elapsed for call 5: 1.627737283706665 seconds
Time elapsed for call 6: 1.7711141109466553 seconds
Time elapsed for call 7: 1.9205939769744873 seconds
Time elapsed for call 8: 1.7319071292877197 seconds
Time elapsed for call 9: 1.7010533809661865 seconds
Time elapsed for call 10: 1.6617672443389893 seconds
Time elapsed for call 11: 1.7591743469238281 seconds
Time elapsed for call 12: 1.6765027046203613 seconds
Time elapsed for call 13: 1.6800673007965088 seconds
Time elapsed for call 14: 1.7507634162902832 seconds
Time elapsed for call 15: 1.8089756965637207 seconds
Time elapsed for call 16: 1.7215228080749512 seconds
Time elapsed for call 17: 1.631645679473877 seconds
Time elapsed for call 18: 1.7276067733764648 seconds
Time elapsed for call 19: 1.6951384544372559 seconds
Time elapsed for call 20: 1.7005517482757568 seconds
Time elapsed for call 21: 1.6384053230285645 seconds
Time elapsed for call 22: 1.6944351196289062 seconds
Time elapsed for call 23: 1.6808490753173828 seconds
Time elapsed for call 24: 1.7491445541381836 seconds
Time elapsed for call 25: 1.8734047412872314 seconds
Time elapsed for call 26: 1.7570459842681885 seconds
Time elapsed for call 27: 1.852790355682373 seconds
Time elapsed for call 28: 1.6827940940856934 seconds
Time elapsed for call 29: 1.680189609527588 seconds

WITH GPU:

Time elapsed for call 0: 25.906383991241455 seconds
Time elapsed for call 1: 0.954118013381958 seconds
Time elapsed for call 2: 1.1041779518127441 seconds
Time elapsed for call 3: 0.9762158393859863 seconds
Time elapsed for call 4: 0.9117069244384766 seconds
Time elapsed for call 5: 0.884854793548584 seconds
Time elapsed for call 6: 0.9398739337921143 seconds
Time elapsed for call 7: 0.9885945320129395 seconds
Time elapsed for call 8: 0.8955085277557373 seconds
Time elapsed for call 9: 0.8905599117279053 seconds
Time elapsed for call 10: 0.8585331439971924 seconds
Time elapsed for call 11: 0.959202766418457 seconds
Time elapsed for call 12: 1.0689916610717773 seconds
Time elapsed for call 13: 0.9468050003051758 seconds
Time elapsed for call 14: 0.9171993732452393 seconds
Time elapsed for call 15: 0.8998894691467285 seconds
Time elapsed for call 16: 0.960838794708252 seconds
Time elapsed for call 17: 0.924644947052002 seconds
Time elapsed for call 18: 0.8758440017700195 seconds
Time elapsed for call 19: 0.9010753631591797 seconds
Time elapsed for call 20: 0.915982723236084 seconds
Time elapsed for call 21: 0.979011058807373 seconds
Time elapsed for call 22: 0.9859118461608887 seconds
Time elapsed for call 23: 0.9595963954925537 seconds
Time elapsed for call 24: 0.9682104587554932 seconds
Time elapsed for call 25: 0.8719127178192139 seconds
Time elapsed for call 26: 0.9665031433105469 seconds
Time elapsed for call 27: 0.8821077346801758 seconds
Time elapsed for call 28: 0.8932716846466064 seconds
Time elapsed for call 29: 0.9520418643951416 seconds
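One note on the timing numbers above: JAX dispatches work to the accelerator asynchronously, so taking time.time() right after the call can measure dispatch (plus, on the first call, compilation) rather than the full device execution. A minimal adjustment, sketched here reusing the names from the timing loop above, is to block on the outputs with jax.block_until_ready (or .block_until_ready() on the individual arrays) before stopping the clock:

# Sketch: block on the outputs so the measurement includes device execution,
# not just asynchronous dispatch (names reused from the timing loop above).
start = time.time()
(logprob_value, net_state), logprob_grad = log_probability.value_and_grad()(
    net_apply, net_state, params, x, y, True
)
jax.block_until_ready(((logprob_value, net_state), logprob_grad))
end = time.time()
print(f"Time elapsed: {end - start} seconds")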
Replies: 1 comment
I forgot to jax.jit the value_and_grad function. To do this with methods within a dataclass you need to do the following:

@dataclass(frozen=True, eq=True)  # for hashability
class foo:
    @partial(jax.jit, static_argnums=(0,))  # possibly more static argnums
    def method(self, *args):
        ...  # do stuff

In my case this looks like:

from functools import partial

@dataclass(eq=True, frozen=True)
class LogProbability:
likelihood: CategoricalLogLikelihood
prior: GaussianLogPrior
num_batches: int
@partial(jax.jit, static_argnums=(0,1))
def value(
self,
net_apply: Callable,
net_state: Dict,
params: Dict,
x: jax.Array,
y: jax.Array,
is_training: bool = False,
):
likelihood_value, net_state = self.likelihood.value(
net_apply, net_state, params, x=x, y=y, is_training=is_training
)
prior_value = self.prior.value(params)
# stochastic optimization
log_prob_value = self.num_batches * likelihood_value + prior_value
return log_prob_value, net_state
def grad(self):
return jax.grad(self.value, argnums=2, has_aux=True)
def value_and_grad(self):
return jax.value_and_grad(self.value, argnums=2, has_aux=True)

New iterations with one GPU:

Iteration 0: 1.5100152492523193
Iteration 1: 0.00691986083984375
Iteration 2: 0.006659269332885742
Iteration 3: 0.0062770843505859375
Iteration 4: 0.006309032440185547
Iteration 5: 0.006249666213989258
Iteration 6: 0.0063860416412353516
Iteration 7: 0.0062372684478759766
Iteration 8: 0.006523609161376953
Iteration 9: 0.006345272064208984
Iteration 10: 0.006281614303588867
Iteration 11: 0.006283998489379883
Iteration 12: 0.006841182708740234
Iteration 13: 0.00635075569152832
Iteration 14: 0.006237983703613281
Iteration 15: 0.006408214569091797
Iteration 16: 0.006109952926635742
Iteration 17: 0.006323337554931641
Iteration 18: 0.0062220096588134766
Iteration 19: 0.006312370300292969
Iteration 20: 0.0061571598052978516
Iteration 21: 0.006295204162597656
Iteration 22: 0.006172657012939453
Iteration 23: 0.0059735774993896484
Iteration 24: 0.006140232086181641
Iteration 25: 0.006115913391113281
Iteration 26: 0.006545543670654297
Iteration 27: 0.006410121917724609
Iteration 28: 0.006188154220581055
Iteration 29: 0.00625300407409668
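For reference, the reason the dataclass needs frozen=True and eq=True is that marking self as a static argument requires it to be hashable, and a frozen dataclass with eq=True gets an auto-generated __hash__. Everything stored on a static self must be hashable as well, so the likelihood and prior objects would likewise need to be frozen (or define their own __hash__). Below is a minimal, self-contained toy illustrating the same pattern in isolation; the Scaler class and its numbers are illustrative only, not part of the code above:

from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp

@dataclass(frozen=True, eq=True)  # frozen + eq generates __hash__, so `self` can be a static arg
class Scaler:
    factor: float

    @partial(jax.jit, static_argnums=(0,))  # argnum 0 is `self`
    def sum_of_squares(self, x):
        # Traced and compiled once per distinct Scaler instance; cached afterwards.
        return self.factor * jnp.sum(x ** 2)

s = Scaler(2.0)
print(s.sum_of_squares(jnp.arange(4.0)))  # first call compiles, later calls hit the cache

Because self is static here, constructing an instance with a different factor triggers a recompilation, which is also why the first measured iteration above is much slower than the rest. Finally, as noted for the timings in the question, the ~6 ms figures may partly reflect asynchronous dispatch; blocking on the outputs before stopping the clock gives a more faithful per-step time.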