Skip to content

Commit 7702e02

Browse files
Release version 0.0.3: Added neuromorphic computing features and input validation improvements
1 parent adec74c commit 7702e02

File tree

5 files changed

+402
-34
lines changed

5 files changed

+402
-34
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# NeuroFlex Features Documentation
2+
3+
## Table of Contents
4+
5+
1. [Introduction](#introduction)
6+
2. [Core Features](#core-features)
7+
3. [Advanced Functionalities](#advanced-functionalities)
8+
3.1. [Quantum Neural Network](#quantum-neural-network)
9+
3.2. [Reinforcement Learning](#reinforcement-learning)
10+
3.3. [Cognitive Architecture](#cognitive-architecture)
11+
3.4. [Neuromorphic Computing](#neuromorphic-computing)
12+
4. [Integrations](#integrations)
13+
4.1. [AlphaFold Integration](#alphafold-integration)
14+
4.2. [JAX, TensorFlow, and PyTorch Support](#jax-tensorflow-and-pytorch-support)
15+
5. [Natural Language Processing](#natural-language-processing)
16+
6. [Performance and Optimization](#performance-and-optimization)
17+
7. [Safety Features](#safety-features)
18+
8. [Usage Examples](#usage-examples)
19+
9. [Future Developments](#future-developments)
20+
21+
## Introduction
22+
23+
NeuroFlex is a cutting-edge, versatile machine learning framework designed to push the boundaries of artificial intelligence. It combines traditional deep learning techniques with advanced quantum computing, reinforcement learning, cognitive architectures, and neuromorphic computing. This documentation provides a comprehensive overview of NeuroFlex's features, capabilities, and integrations. NeuroFlex supports multiple Python versions, ensuring compatibility across various development environments and enhancing its versatility for researchers and practitioners alike.
24+
25+
## Core Features
26+
27+
- **Advanced Neural Network Architectures**: Supports a wide range of neural networks, including CNNs, RNNs, LSTMs, GANs, and Spiking Neural Networks, providing flexibility for diverse machine learning tasks.
28+
- **Multi-Backend Support**: Seamlessly integrates with JAX, TensorFlow, and PyTorch, allowing users to leverage the strengths of each framework.
29+
- **Quantum Computing Integration**: Incorporates quantum neural networks for enhanced computational capabilities and exploration of quantum machine learning algorithms.
30+
- **Reinforcement Learning**: Robust support for RL algorithms and environments, enabling the development of intelligent agents for complex decision-making tasks.
31+
- **Advanced Natural Language Processing**: Includes tokenization, grammar correction, and state-of-the-art language models for sophisticated text processing and generation.
32+
- **Bioinformatics Tools**: Integrates with AlphaFold and other bioinformatics libraries, facilitating advanced protein structure prediction and analysis.
33+
- **Self-Curing Algorithms**: Implements adaptive learning and self-improvement mechanisms for enhanced model robustness and reliability.
34+
- **Fairness and Ethical AI**: Incorporates fairness constraints and ethical considerations in model training, promoting responsible AI development.
35+
- **Brain-Computer Interface (BCI) Support**: Provides functionality for processing and analyzing brain signals, enabling the development of advanced BCI applications.
36+
- **Cognitive Architecture**: Implements sophisticated cognitive models that simulate human-like reasoning and decision-making processes.
37+
- **Neuromorphic Computing**: Implements spiking neural networks for energy-efficient, brain-inspired computing.
38+
39+
## Advanced Functionalities
40+
41+
### Quantum Neural Network
42+
43+
NeuroFlex integrates quantum computing capabilities through its QuantumNeuralNetwork module. This hybrid quantum-classical approach leverages the power of quantum circuits to enhance computational capabilities. Key features include:
44+
45+
- Variational quantum circuits with customizable number of qubits and layers
46+
- Hybrid quantum-classical computations using JAX for seamless integration
47+
- Adaptive quantum circuit execution with error handling and classical fallback
48+
49+
### Reinforcement Learning
50+
51+
The framework provides robust support for reinforcement learning, enabling the development of intelligent agents that learn from interaction with their environment. Notable features include:
52+
53+
- Flexible RL agent architecture with support for various algorithms (e.g., DQN, Policy Gradient)
54+
- Integration with popular RL environments (e.g., OpenAI Gym)
55+
- Advanced training utilities including replay buffers, epsilon-greedy exploration, and learning rate scheduling
56+
57+
### Cognitive Architecture and Brain-Computer Interface (BCI)
58+
59+
NeuroFlex implements an advanced cognitive architecture that simulates complex cognitive processes, bridging the gap between traditional neural networks and human-like reasoning. This architecture is further enhanced with Brain-Computer Interface (BCI) capabilities, allowing for direct interaction between neural systems and external devices. Key aspects include:
60+
61+
- Multi-layer cognitive processing pipeline with advanced neural network architectures (CNN, RNN, LSTM, GAN)
62+
- Simulated attention mechanisms, working memory, and metacognition components
63+
- Integration of decision-making processes and adaptive learning algorithms
64+
- BCI functionality for real-time neural signal processing and interpretation
65+
- Advanced feature extraction techniques for BCI, including wavelet transforms and adaptive filtering
66+
- Cognitive state estimation and intent decoding for intuitive human-machine interaction
67+
- Seamless integration of cognitive models with quantum computing modules for enhanced problem-solving capabilities
68+
69+
### Neuromorphic Computing
70+
71+
NeuroFlex now includes advanced neuromorphic computing capabilities through its SpikingNeuralNetwork module. This biologically-inspired approach mimics the behavior of neurons in the brain, offering energy-efficient and highly parallel computation. Key features include:
72+
73+
- Customizable spiking neural network architecture with flexible neuron counts per layer
74+
- Biologically plausible neuron models with adjustable threshold, reset potential, and leak factor
75+
- Input validation and automatic reshaping for robust handling of various input formats
76+
- Support for both 1D and 2D input tensors, with automatic adjustment for batch processing
77+
- Efficient implementation using JAX for high-performance computing
78+
- Customizable activation functions and spike generation mechanisms
79+
- Integration with other NeuroFlex modules for hybrid AI systems
80+
81+
## Integrations
82+
83+
[... Rest of the content remains unchanged ...]
84+
85+
## Usage Examples
86+
87+
[... Previous examples remain unchanged ...]
88+
89+
### Neuromorphic Computing with Spiking Neural Networks
90+
91+
```python
92+
from NeuroFlex.neuromorphic_computing import SpikingNeuralNetwork
93+
import jax.numpy as jnp
94+
95+
# Create a spiking neural network
96+
snn = SpikingNeuralNetwork(num_neurons=[64, 32, 10])
97+
98+
# Example input (can be 1D or 2D)
99+
input_data = jnp.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
100+
101+
# Initialize the network
102+
rng = jax.random.PRNGKey(0)
103+
params = snn.init(rng, input_data)
104+
105+
# Run the network
106+
output, membrane_potentials = snn.apply(params, input_data)
107+
print("SNN output:", output)
108+
print("Membrane potentials:", membrane_potentials)
109+
```
110+
111+
These examples demonstrate some of the key features of the NeuroFlex framework. For more detailed usage and advanced features, please refer to the specific module documentation.
112+
113+
## Future Developments
114+
115+
[... Rest of the content remains unchanged ...]

setup.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
setup(
44
name="neuroflex",
5-
version="0.0.1",
5+
version="0.0.3",
66
author="kasinadhsarma",
77
author_email="kasinadhsarma@gmail.com",
88
description="An advanced neural network framework with interpretability, generalization, robustness, and fairness features",
99
long_description=open("README.md").read(),
1010
long_description_content_type="text/markdown",
1111
url="https://github.com/VishwamAI/neuroflex",
12-
packages=find_packages(),
12+
packages=find_packages(where='src'),
13+
package_dir={'': 'src'},
1314
classifiers=[
1415
"Development Status :: 3 - Alpha",
1516
"Intended Audience :: Developers",
@@ -22,37 +23,45 @@
2223
],
2324
python_requires=">=3.8",
2425
install_requires=[
25-
"jax==0.4.10",
26-
"jaxlib==0.4.10",
27-
"ml_dtypes==0.2.0",
28-
"flax==0.7.2",
29-
"optax==0.1.7",
30-
"tensorflow-cpu==2.16.1",
31-
"keras==3.5.0",
32-
"gym==0.26.2",
33-
"pytest==7.4.0",
34-
"flake8==6.0.0",
35-
"numpy==1.24.3",
36-
"scipy==1.10.1",
37-
"matplotlib==3.7.1",
38-
"aif360==0.5.0",
39-
"packaging==23.1",
40-
"gast==0.6.0",
41-
"wrapt==1.16.0",
42-
"pennylane==0.32.0",
43-
"ibm-watson-machine-learning>=1.0.257",
44-
"scikit-learn>=1.2.2",
45-
"pandas>=2.0.2",
46-
"adversarial-robustness-toolbox>=1.15.0",
47-
"lale>=0.7.0",
48-
"qutip>=4.7.1",
49-
"pyquil>=3.5.4",
50-
"qiskit>=0.43.0",
51-
"biopython>=1.81",
52-
"scikit-bio>=0.5.8",
53-
"ete3>=3.1.2",
54-
"xarray>=2023.5.0",
55-
"torch>=2.0.1",
56-
"alphafold==2.0.0",
26+
"jax>=0.3.0",
27+
"jaxlib>=0.3.0",
28+
"ml_dtypes",
29+
"flax>=0.6.0",
30+
"optax",
31+
"tensorflow-cpu",
32+
"keras",
33+
"gym",
34+
"pytest",
35+
"flake8",
36+
"numpy",
37+
"scipy",
38+
"matplotlib",
39+
"aif360",
40+
"packaging",
41+
"gast",
42+
"wrapt",
43+
"pennylane",
44+
"ibm-watson-machine-learning",
45+
"scikit-learn",
46+
"pandas",
47+
"adversarial-robustness-toolbox",
48+
"lale",
49+
"qutip",
50+
"pyquil",
51+
"qiskit",
52+
"biopython",
53+
"scikit-bio",
54+
"ete3",
55+
"xarray",
56+
"torch",
57+
# Removed direct GitHub dependency: "alphafold @ git+https://github.com/google-deepmind/alphafold.git"
58+
# If needed, install alphafold separately or specify a PyPI-compatible version
59+
"shap",
5760
],
61+
extras_require={
62+
'dev': [
63+
'pytest',
64+
'flake8',
65+
],
66+
},
5867
)

src/NeuroFlex/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
__version__ = "0.0.3"
2+
13
# Import main components
24
from .advanced_thinking import NeuroFlex, data_augmentation, create_train_state, select_action, adversarial_training
35
from .machinelearning import NeuroFlexClassifier
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import jax
2+
import jax.numpy as jnp
3+
import flax.linen as nn
4+
import optax
5+
from typing import List, Tuple, Callable, Optional
6+
import logging
7+
8+
def spiking_neuron(x, membrane_potential, threshold=1.0, reset_potential=0.0, leak_factor=0.9):
9+
new_membrane_potential = jnp.add(leak_factor * membrane_potential, x)
10+
spike = jnp.where(new_membrane_potential >= threshold, 1.0, 0.0)
11+
new_membrane_potential = jnp.where(spike == 1.0, reset_potential, new_membrane_potential)
12+
return spike, new_membrane_potential
13+
14+
class SpikingNeuralNetwork(nn.Module):
15+
num_neurons: List[int]
16+
activation: Callable = nn.relu
17+
spike_function: Callable = lambda x: jnp.where(x > 0, 1.0, 0.0)
18+
threshold: float = 1.0
19+
reset_potential: float = 0.0
20+
leak_factor: float = 0.9
21+
22+
@nn.compact
23+
def __call__(self, inputs, membrane_potentials=None):
24+
logging.debug(f"Input shape: {inputs.shape}")
25+
x = inputs
26+
27+
# Input validation and reshaping
28+
if len(inputs.shape) == 1:
29+
x = jnp.expand_dims(x, axis=0)
30+
elif len(inputs.shape) > 2:
31+
x = jnp.reshape(x, (-1, x.shape[-1]))
32+
33+
if x.shape[1] != self.num_neurons[0]:
34+
raise ValueError(f"Input shape {x.shape} does not match first layer neurons {self.num_neurons[0]}")
35+
36+
if membrane_potentials is None:
37+
membrane_potentials = [jnp.zeros((x.shape[0], num_neuron)) for num_neuron in self.num_neurons]
38+
else:
39+
if len(membrane_potentials) != len(self.num_neurons):
40+
raise ValueError(f"Expected {len(self.num_neurons)} membrane potentials, got {len(membrane_potentials)}")
41+
membrane_potentials = [jnp.broadcast_to(mp, (x.shape[0], mp.shape[-1])) for mp in membrane_potentials]
42+
43+
logging.debug(f"Adjusted input shape: {x.shape}")
44+
logging.debug(f"Adjusted membrane potentials shapes: {[mp.shape for mp in membrane_potentials]}")
45+
46+
new_membrane_potentials = []
47+
for i, (num_neuron, membrane_potential) in enumerate(zip(self.num_neurons, membrane_potentials)):
48+
logging.debug(f"Layer {i} - Input shape: {x.shape}, Membrane potential shape: {membrane_potential.shape}")
49+
50+
spiking_layer = jax.vmap(lambda x, mp: spiking_neuron(x, mp, self.threshold, self.reset_potential, self.leak_factor),
51+
in_axes=(0, 0), out_axes=0)
52+
spikes, new_membrane_potential = spiking_layer(x, membrane_potential)
53+
54+
logging.debug(f"Layer {i} - Spikes shape: {spikes.shape}, New membrane potential shape: {new_membrane_potential.shape}")
55+
56+
x = self.activation(spikes)
57+
new_membrane_potentials.append(new_membrane_potential)
58+
59+
# Adjust x for the next layer
60+
if i < len(self.num_neurons) - 1:
61+
x = nn.Dense(self.num_neurons[i+1])(x)
62+
63+
logging.debug(f"Final output shape: {x.shape}")
64+
return self.spike_function(x), new_membrane_potentials
65+
66+
class NeuromorphicComputing(nn.Module):
67+
num_neurons: List[int]
68+
threshold: float = 1.0
69+
reset_potential: float = 0.0
70+
leak_factor: float = 0.9
71+
72+
def setup(self):
73+
self.model = SpikingNeuralNetwork(num_neurons=self.num_neurons,
74+
threshold=self.threshold,
75+
reset_potential=self.reset_potential,
76+
leak_factor=self.leak_factor)
77+
logging.info(f"Initialized NeuromorphicComputing with {len(self.num_neurons)} layers")
78+
79+
def __call__(self, inputs, membrane_potentials=None):
80+
return self.model(inputs, membrane_potentials)
81+
82+
def init_model(self, rng, input_shape):
83+
dummy_input = jnp.zeros(input_shape)
84+
membrane_potentials = [jnp.zeros(input_shape[:-1] + (n,)) for n in self.num_neurons]
85+
# Ensure consistent shapes between inputs and membrane potentials
86+
if dummy_input.shape[1] != membrane_potentials[0].shape[1]:
87+
dummy_input = jnp.reshape(dummy_input, (-1, membrane_potentials[0].shape[1]))
88+
return self.init(rng, dummy_input, membrane_potentials)
89+
90+
@jax.jit
91+
def forward(self, params, inputs, membrane_potentials):
92+
return self.apply(params, inputs, membrane_potentials)
93+
94+
def train_step(self, params, inputs, targets, membrane_potentials, optimizer):
95+
def loss_fn(params):
96+
outputs, new_membrane_potentials = self.forward(params, inputs, membrane_potentials)
97+
return jnp.mean((outputs - targets) ** 2), new_membrane_potentials
98+
99+
(loss, new_membrane_potentials), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
100+
updates, optimizer_state = optimizer.update(grads, optimizer.state)
101+
params = optax.apply_updates(params, updates)
102+
optimizer = optimizer.replace(state=optimizer_state)
103+
return params, loss, new_membrane_potentials, optimizer
104+
105+
@staticmethod
106+
def handle_error(e: Exception) -> None:
107+
logging.error(f"Error in NeuromorphicComputing: {str(e)}")
108+
if isinstance(e, jax.errors.JAXException):
109+
logging.error("JAX-specific error occurred. Check JAX configuration and input shapes.")
110+
elif isinstance(e, ValueError):
111+
logging.error("Value error occurred. Check input data and model parameters.")
112+
else:
113+
logging.error("Unexpected error occurred. Please review the stack trace for more information.")
114+
raise
115+
116+
def create_neuromorphic_model(num_neurons: List[int]) -> NeuromorphicComputing:
117+
return NeuromorphicComputing(num_neurons=num_neurons)

0 commit comments

Comments
 (0)