"""Backend detection and configuration utilities."""
import os
import warnings
[docs]
def detect_available_backends():
    """
    Probe the current environment for importable Keras backends.

    A backend counts as available only when every import in its probe
    succeeds; the JAX probe imports both ``jax`` and ``jaxlib``.

    Returns:
        list: Names of the importable backends, always in probe order:
            'jax', then 'torch', then 'tensorflow'. Empty when none of
            them can be imported.
    """
    found = []

    # JAX: both the front-end package and jaxlib must import.
    try:
        import jax  # noqa: F401
        import jaxlib  # noqa: F401
    except ImportError:
        pass
    else:
        found.append('jax')

    # PyTorch
    try:
        import torch  # noqa: F401
    except ImportError:
        pass
    else:
        found.append('torch')

    # TensorFlow
    try:
        import tensorflow  # noqa: F401
    except ImportError:
        pass
    else:
        found.append('tensorflow')

    return found
[docs]
def get_best_backend():
    """
    Pick the preferred backend among those that can be imported.

    Returns:
        str: Name of the highest-ranked available backend.

    Raises:
        RuntimeError: If no backends are available.
    """
    candidates = detect_available_backends()
    if not candidates:
        raise RuntimeError(
            "No Keras backends available. Please install at least one of: "
            "jax, torch, or tensorflow"
        )

    # Ranked most-preferred first: PyTorch > JAX > TensorFlow
    # (PyTorch is more stable for our use case).
    ranking = ('torch', 'jax', 'tensorflow')

    # If nothing in the ranking matches, fall back to the first detected entry.
    return next(
        (name for name in ranking if name in candidates),
        candidates[0],
    )
[docs]
def setup_test_backend():
    """
    Configure a Keras backend for testing, with a best-effort fallback.

    The primary path calls ``configure_backend()`` (defined outside this
    block) and then applies the backend-specific settings. If anything in
    that path raises, a warning line is printed and the first backend
    reported by ``detect_available_backends()`` is selected through the
    ``KERAS_BACKEND`` environment variable instead.

    Returns:
        str or None: Name of the configured backend, or ``None`` when the
            primary path failed and no backend could be detected.
    """
    try:
        chosen = configure_backend()
        backend_specific_config(chosen)
        print(f"Using Keras backend: {chosen}")
        return chosen
    except Exception as e:
        print(f"Warning: Failed to configure backend: {e}")

        # Primary path failed: settle for whatever can be imported.
        detected = detect_available_backends()
        if not detected:
            return None

        chosen = detected[0]
        os.environ["KERAS_BACKEND"] = chosen
        backend_specific_config(chosen)
        print(f"Fallback to: {chosen}")
        return chosen
[docs]
def _env_flag(name, default):
    """Read a boolean environment variable, accepting common truthy spellings."""
    value = os.environ.get(name, default).strip().lower()
    return value in ('1', 'true', 't', 'yes', 'y', 'on')


def _configure_jax():
    """Apply the JAX platform, x64 and JIT settings taken from the environment."""
    import jax

    # Defaults when the variables are unset: CPU platform, 64-bit enabled,
    # JIT disabled.
    jax.config.update('jax_platforms', os.environ.get('JAX_PLATFORMS', 'cpu'))
    jax.config.update('jax_enable_x64', _env_flag('JAX_ENABLE_X64', 'true'))
    jax.config.update('jax_disable_jit', _env_flag('JAX_DISABLE_JIT', 'true'))

    # Turn off memory preallocation unless the user already chose a value.
    os.environ.setdefault('XLA_PYTHON_CLIENT_PREALLOCATE', 'false')


def _configure_torch():
    """Set PyTorch's default dtype to float32 and pin CPU when CUDA is absent."""
    import torch

    torch.set_default_dtype(torch.float32)
    if not torch.cuda.is_available():
        torch.set_default_device('cpu')


def _configure_tensorflow():
    """Quiet TensorFlow logging and hide all GPUs from it."""
    # Set before the import so the native log level is in place the first
    # time TensorFlow is loaded; setting it afterwards is too late to
    # silence import-time messages.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    import tensorflow as tf

    tf.get_logger().setLevel('ERROR')
    # Hides every GPU unconditionally, so TensorFlow runs on CPU.
    tf.config.set_visible_devices([], 'GPU')


def backend_specific_config(backend_name):
    """
    Apply backend-specific configurations.

    Configuration is best-effort: if the backend cannot be imported or a
    setting cannot be applied, a ``RuntimeWarning`` is issued and the
    function returns normally instead of raising.

    Environment variables read for the 'jax' backend:
        JAX_PLATFORMS (default 'cpu'), JAX_ENABLE_X64 (default 'true'),
        JAX_DISABLE_JIT (default 'true'). The two boolean flags accept
        1/true/t/yes/y/on (case-insensitive) as true; anything else is false.

    Args:
        backend_name (str): Name of the backend to configure: 'jax',
            'torch' or 'tensorflow'. Any other name is ignored.
    """
    configurators = {
        'jax': _configure_jax,
        'torch': _configure_torch,
        'tensorflow': _configure_tensorflow,
    }
    configure = configurators.get(backend_name)
    if configure is None:
        return

    try:
        configure()
    except Exception as exc:
        # Deliberately broad: a missing package or an unsupported setting
        # must not abort the caller, but it should not pass unnoticed.
        warnings.warn(
            f"Could not apply {backend_name}-specific configuration: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
# Importing this module configures nothing; auto-configuration runs only
# when the file is executed directly as a script.
# NOTE(review): auto_configure is not defined in the visible part of this
# file — confirm it exists before relying on this entry point.
if __name__ == "__main__":
    auto_configure()