# JAX-specific model, training, and prediction utilities

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
import optax
from collections.abc import Mapping
from functools import partial
from typing import Any, Tuple, List, Callable
import logging

logging.basicConfig(level=logging.INFO)

# Flexible model using JAX
class JAXModel(nn.Module):
    features: List[int]
    use_cnn: bool = False
    conv_dim: int = 2
    dtype: jnp.dtype = jnp.float32
    activation: Callable = nn.relu

    def setup(self):
        if self.use_cnn:
            if self.conv_dim not in (2, 3):
                raise ValueError(f"Invalid conv_dim: {self.conv_dim}. Must be 2 or 3.")
            kernel_size = (3, 3) if self.conv_dim == 2 else (3, 3, 3)
            self.conv_layers = [nn.Conv(features=feat, kernel_size=kernel_size, padding='SAME', dtype=self.dtype)
                                for feat in self.features[:-1]]
        else:
            self.dense_layers = [nn.Dense(feat, dtype=self.dtype) for feat in self.features[:-1]]
        self.final_layer = nn.Dense(self.features[-1], dtype=self.dtype)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        if self.use_cnn:
            expected_dim = self.conv_dim + 2  # batch, height, width, (depth), channels
            if x.ndim != expected_dim:
                raise ValueError(f"Expected input dimension {expected_dim}, got {x.ndim}")
            for layer in self.conv_layers:
                x = self.activation(layer(x))
                x = nn.max_pool(x, window_shape=(2,) * self.conv_dim, strides=(2,) * self.conv_dim)
            x = x.reshape((x.shape[0], -1))  # Flatten the spatial dimensions
        else:
            if x.ndim != 2:
                raise ValueError(f"Expected 2D input for DNN, got {x.ndim}D")
            for layer in self.dense_layers:
                x = self.activation(layer(x))
        return self.final_layer(x)
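
# Hedged usage sketch (not part of the original source): how JAXModel is expected
# to be instantiated and initialized before training. The helper name, feature
# sizes, and input dimension below are illustrative assumptions.
def example_init_dense_model(input_dim: int = 10) -> Tuple[JAXModel, Any]:
    model = JAXModel(features=[64, 32, 1])        # two hidden layers, scalar output (assumed sizes)
    dummy_x = jnp.ones((1, input_dim))            # one dummy sample to trigger shape inference
    variables = model.init(jax.random.PRNGKey(0), dummy_x)
    return model, variables['params']             # the 'params' pytree is what train_jax_model consumes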

# JAX-based training function with flexible loss and optimizer
def train_jax_model(
    model: JAXModel,
    params: Any,
    X: jnp.ndarray,
    y: jnp.ndarray,
    loss_fn: Callable = lambda pred, y: jnp.mean((pred - y) ** 2),
    epochs: int = 100,
    patience: int = 20,
    min_delta: float = 1e-6,
    batch_size: int = 32,
    learning_rate: float = 1e-3,
    grad_clip_value: float = 1.0
) -> Tuple[Any, float, List[float]]:
    num_samples = X.shape[0]
    num_batches = max(1, int(np.ceil(num_samples / batch_size)))
    total_steps = epochs * num_batches

    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=learning_rate * 0.1,
        peak_value=learning_rate,
        warmup_steps=min(100, total_steps // 10),
        decay_steps=total_steps,
        end_value=learning_rate * 0.01
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(lr_schedule)
    )
    opt_state = optimizer.init(params)

    # Build the jitted update step via a factory so the training loop can rebuild
    # it whenever the optimizer is replaced (the step closes over the optimizer).
    def make_update(opt: optax.GradientTransformation) -> Callable:
        @jax.jit
        def update(params: Any, opt_state: Any, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[Any, Any, jnp.ndarray, Any]:
            def loss_wrapper(params):
                pred = model.apply({'params': params}, x)
                return loss_fn(pred, y)
            loss, grads = jax.value_and_grad(loss_wrapper)(params)
            updates, new_opt_state = opt.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, new_opt_state, loss, grads
        return update

    update = make_update(optimizer)

    best_loss = float('inf')
    best_params = params
    patience_counter = 0
    training_history = []
    plateau_threshold = 1e-8
    plateau_count = 0
    max_plateau_count = 15

    try:
        for epoch in range(epochs):
            epoch_loss = 0.0
            for i in range(num_batches):
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, num_samples)
                batch_X = X[start_idx:end_idx]
                batch_y = y[start_idx:end_idx]

                # Ensure batch_X and batch_y have consistent shapes
                if batch_X.shape[0] != batch_y.shape[0]:
                    min_size = min(batch_X.shape[0], batch_y.shape[0])
                    batch_X = batch_X[:min_size]
                    batch_y = batch_y[:min_size]

                new_params, new_opt_state, batch_loss, grads = update(params, opt_state, batch_X, batch_y)

                if not jnp.isfinite(batch_loss):
                    logging.warning(f"Non-finite loss detected: {batch_loss}. Skipping this batch.")
                    continue

                # Only accept the step once the loss is known to be finite
                params, opt_state = new_params, new_opt_state
                epoch_loss += batch_loss

            avg_epoch_loss = epoch_loss / num_batches

            training_history.append(avg_epoch_loss)

            logging.info(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_epoch_loss:.6f}")

            if avg_epoch_loss < best_loss - min_delta:
                best_loss = avg_epoch_loss
                best_params = jax.tree_util.tree_map(lambda p: p.copy(), params)  # Snapshot the best params
                patience_counter = 0
                plateau_count = 0
                logging.info(f"New best loss: {best_loss:.6f}")
            else:
                patience_counter += 1
                if abs(avg_epoch_loss - best_loss) < plateau_threshold:
                    plateau_count += 1
                    logging.info(f"Plateau detected. Count: {plateau_count}")
                else:
                    plateau_count = 0

            if patience_counter >= patience:
                logging.info(f"Early stopping due to no improvement for {patience} epochs")
                break
            elif plateau_count >= max_plateau_count:
                logging.info(f"Early stopping due to {max_plateau_count} plateaus")
                break

            # Check whether the loss is still decreasing
            if epoch > 0 and avg_epoch_loss > training_history[-2] * 1.1:  # 10% tolerance
                logging.warning(f"Loss increased significantly: {training_history[-2]:.6f} -> {avg_epoch_loss:.6f}")
                # Reduce the learning rate on a significant loss increase
                current_lr = lr_schedule(epoch * num_batches)
                new_lr = current_lr * 0.5
                lr_schedule = optax.exponential_decay(
                    init_value=new_lr,
                    transition_steps=num_batches,
                    decay_rate=0.99
                )
                optimizer = optax.chain(
                    optax.clip_by_global_norm(grad_clip_value),
                    optax.adam(lr_schedule)
                )
                opt_state = optimizer.init(params)
                # Rebuild the jitted update step so it picks up the new optimizer
                update = make_update(optimizer)
                logging.info(f"Reduced learning rate to {new_lr:.6f}")

            # Monitor the gradient norm of the last batch
            grad_norm = optax.global_norm(jax.tree_util.tree_map(lambda g: g.astype(jnp.float32), grads))
            logging.info(f"Gradient norm: {grad_norm:.6f}")

            # Add gradient noise when the gradients have all but vanished
            if grad_norm < 1e-6:
                noise_scale = 1e-6
                rng = jax.random.PRNGKey(epoch)
                noisy_grads = jax.tree_util.tree_map(
                    lambda g: g + jax.random.normal(rng, g.shape) * noise_scale, grads)
                updates, opt_state = optimizer.update(noisy_grads, opt_state, params)
                params = optax.apply_updates(params, updates)
                logging.info("Added gradient noise due to small gradient norm")

    except Exception as e:
        logging.error(f"Error during training: {str(e)}")
        raise

    # Ensure consistent parameter dtypes
    best_params = jax.tree_util.tree_map(lambda p: p.astype(jnp.float32), best_params)

    logging.info(f"Training completed. Best loss: {best_loss:.6f}")
    return best_params, best_loss, training_history
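
# Hedged usage sketch (not part of the original source): trains the dense model on
# synthetic regression data. The data shapes, feature sizes, and hyperparameters
# below are illustrative assumptions, and the helper name is hypothetical.
def example_train_dense_model() -> Tuple[Any, float, List[float]]:
    X = jax.random.normal(jax.random.PRNGKey(42), (256, 10))   # 256 samples, 10 features (assumed)
    y = jnp.sum(X, axis=1, keepdims=True)                      # simple synthetic target
    model = JAXModel(features=[64, 32, 1])
    params = model.init(jax.random.PRNGKey(0), X[:1])['params']
    return train_jax_model(model, params, X, y, epochs=10, batch_size=32)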

# Improved batch prediction with better error handling. use_cnn and conv_dim drive
# Python-level control flow, so they are marked as static arguments for jit.
@partial(jax.jit, static_argnames=('use_cnn', 'conv_dim'))
def batch_predict(params: Any, x: jnp.ndarray, use_cnn: bool = False, conv_dim: int = 2) -> jnp.ndarray:
    try:
        # Validate the params structure
        if not isinstance(params, Mapping):
            raise ValueError("params must be a dict-like mapping")

        # Recover the layer structure from the params keys, which follow the attribute
        # names assigned in JAXModel.setup: 'conv_layers_*', 'dense_layers_*', 'final_layer'
        hidden_keys = sorted(
            (k for k in params if k.startswith(('dense_layers_', 'conv_layers_'))),
            key=lambda k: int(k.rsplit('_', 1)[-1])
        )
        if 'final_layer' not in params:
            raise ValueError("No 'final_layer' found in params")

        # Dynamically create a model matching the params structure
        features = [params[k]['kernel'].shape[-1] for k in hidden_keys]
        features.append(params['final_layer']['kernel'].shape[-1])
        model = JAXModel(features=features, use_cnn=use_cnn, conv_dim=conv_dim)

        # Ensure the input is a JAX array and handle different input shapes
        x = jnp.asarray(x)
        original_shape = x.shape
        if use_cnn:
            expected_dims = conv_dim + 2  # batch, height, width, (depth), channels
            if x.ndim == expected_dims - 1:
                x = x.reshape(1, *x.shape)  # Add a batch dimension for a single image
            elif x.ndim != expected_dims:
                raise ValueError(f"Invalid input shape for CNN. Expected {expected_dims} dimensions, got {x.ndim}. Input shape: {original_shape}")
        else:
            if x.ndim == 1:
                x = x.reshape(1, -1)
            elif x.ndim == 0:
                x = x.reshape(1, 1)
            elif x.ndim != 2:
                raise ValueError(f"Invalid input shape for DNN. Expected 2 dimensions, got {x.ndim}. Input shape: {original_shape}")

        # Ensure x has the correct input dimension
        first_layer_key = hidden_keys[0] if hidden_keys else 'final_layer'
        expected_input_dim = params[first_layer_key]['kernel'].shape[0]
        if not use_cnn and x.shape[-1] != expected_input_dim:
            raise ValueError(f"Input dimension mismatch. Expected {expected_input_dim}, got {x.shape[-1]}. Input shape: {original_shape}")

        # Apply the model
        output = model.apply({'params': params}, x)

        # Reshape the output to match the input shape if necessary
        if len(original_shape) > 2 and not use_cnn:
            output = output.reshape(original_shape[:-1] + (-1,))
        elif len(original_shape) == 0:
            output = output.squeeze()

        # Note: under jit this logs at trace time only
        logging.info(f"Batch prediction successful. Input shape: {original_shape}, Output shape: {output.shape}")
        return output
    except ValueError as ve:
        logging.error(f"ValueError in batch_predict: {str(ve)}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in batch_predict: {str(e)}")
        raise RuntimeError(f"Batch prediction failed: {str(e)}")
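
# Hedged usage sketch (not part of the original source): runs batch_predict on the
# parameters of a freshly initialized dense model, reusing the hypothetical
# example_init_dense_model helper above. The input shape is an assumption.
def example_batch_predict() -> jnp.ndarray:
    _, params = example_init_dense_model(input_dim=10)
    x = jnp.ones((5, 10))                         # 5 samples, 10 assumed input features
    return batch_predict(params, x, use_cnn=False, conv_dim=2)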

# Example of using pmap for multi-device computation. train_jax_model itself runs a
# Python-side loop with logging and cannot be pmapped directly, so a single gradient
# step (plain SGD with an illustrative 1e-3 step size) is mapped over per-device
# data shards instead; params, x, and y must carry a leading device axis.
def make_parallel_train_step(model: JAXModel) -> Callable:
    def step(params: Any, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[Any, jnp.ndarray]:
        loss, grads = jax.value_and_grad(lambda p: jnp.mean((model.apply({'params': p}, x) - y) ** 2))(params)
        params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
        return params, loss
    return jax.pmap(step)
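
# Hedged usage sketch (not part of the original source): replicates the parameters
# and shards synthetic data so the pmapped step runs once per local device. Shapes
# and the hypothetical example_init_dense_model helper are assumptions.
def example_parallel_step() -> None:
    n_dev = jax.local_device_count()
    model, params = example_init_dense_model(input_dim=10)
    step = make_parallel_train_step(model)
    replicated = jax.device_put_replicated(params, jax.local_devices())
    x = jnp.ones((n_dev, 8, 10))                  # [devices, per-device batch, features]
    y = jnp.ones((n_dev, 8, 1))
    new_params, losses = step(replicated, x, y)
    logging.info(f"Per-device losses: {losses}")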