Reverse Dependencies of optax
The following projects have a declared dependency on optax:
- jax-relax — JAX-based Recourse Explanation Library
- jax-verify — A library for neural network verification.
- jaxDiversity — jax implementation for metalearning neuronal diversity
- jaxex — A tool for creating science experiments in jax, torch, brax, etc
- jaxfss — JAX/Flax implementation of finite-size scaling
- jaxmarl — Multi-Agent Reinforcement Learning with JAX
- jaxns — Nested Sampling in JAX
- jaxonmodels — JAX models for deep learning
- jaxparrow — Computes the inversion of the cyclogeostrophic balance based on a variational formulation approach, using JAX
- jaxrie — Riemannian JAX
- jaxsnn — jaxsnn is an event-based approach to machine-learning-inspired training and simulation of SNNs, including support for neuromorphic backends (BrainScaleS-2).
- jinns — Physics Informed Neural Network with JAX
- jo3mnist — no summary
- jraph — Jraph: A library for Graph Neural Networks in Jax
- jumanji — A diverse suite of scalable reinforcement learning environments in JAX
- juxtapose — no summary
- jvt — Vision transformers with JAX & Flax
- kfac-jax — A Jax package for approximate curvature estimation and optimization using KFAC.
- kira_llm — That's right, I'm Kira ✍️
- lagrangebench — LagrangeBench: A Lagrangian Fluid Mechanics Benchmarking Suite
- last-asr — The LAttice-based Speech Transducer (LAST) library
- lazyqml — LazyQML benchmarking utility to test quantum machine learning models.
- learned-optimization — Train learned optimizers in Jax.
- levanter — Scalable Training for Foundation Models with Named Tensors and JAX
- lightcurver — A thorough structure for precise photometry and deconvolution of time series of wide field images.
- long-range-models — Simple Flax implementations of long-range sequence models
- matsim-tools — MATSim Agent-Based Transportation Simulation Framework - official python tools
- maxspin — Estimate spatial information in spatial -omics datasets.
- meent — no summary
- minimax-lib — Efficient baselines for autocurricula in JAX
- mlp-gpt-jax — MLP GPT - Jax
- mmpdenet — no summary
- modularbayes — Modular Bayesian Inference.
- MOGPJax — Didactic multi-output Gaussian processes in Jax.
- momaland — A standard API for Multi-Objective Multi-Agent Decision making and a diverse set of reference environments.
- moss-rl — A Python library for Reinforcement Learning.
- muax — A library written in Jax that provides help for using DeepMind's mctx on gym-style environments.
- mw-adapter-transformers — A friendly fork of HuggingFace's Transformers, adding Adapters to PyTorch language models
- nadl — Nasy's Deep Learning Toolkit
- nanodl — A Jax-based library for designing and training transformer models from scratch.
- netket — Netket: Machine Learning toolbox for many-body quantum systems.
- nndp — Dynamic Programming using Neural Networks
- nnx — no summary
- npiv — Nonparametric IV Toolbox for Python
- numpyro — Pyro PPL on NumPy
- optax-adan — An implementation of adan optimization algorithm for optax.
- optax-swag — Stochastic Weight Averaging for Optax
- ott-jax — Optimal Transport Tools in JAX.
- PaLM-jax — PaLM: Scaling Language Modeling with Pathways - Jax
- paltax — Strong lensing package using jax
- par-segmentation — Cell cortex segmentation and quantification in C. elegans PAR protein images
- penzai — Penzai: A JAX research toolkit for building, editing, and visualizing neural networks.
- phlash — Bayesian inference of population size history from recombining sequence data.
- praxis — Functionalities such as layers for building neural networks in Jax.
- precondition-opt — Preconditioning optimizers.
- prfr — Probabilistic random forest regression algorithm
- probabilistic-reconciliation — Probabilistic reconciliation of time series forecasts
- probdiffeq — Probabilistic numerical solvers for differential equations
- progen-transformer — Protein Generation (ProGen)
- prophetverse — no summary
- protes — Method PROTES (PRobabilistic Optimizer with TEnsor Sampling) for derivative-free optimization of the multidimensional arrays and discretized multivariate functions based on the tensor train (TT) format
- pybefit — Probabilistic inference for models of behaviour
- pycollimator — Collimator.ai core simulation engine and API client
- pyRDDLGym-jax — pyRDDLGym-jax: JAX compilation of RDDL description files, and a differentiable planner in JAX.
- qadence — Pasqal interface for circuit-based quantum computing SDKs
- qax — A JAX transform for writing things which pretend to be tensors
- qdax — A Python Library for Quality-Diversity and NeuroEvolution
- quask — A software framework to speed up the research in quantum machine learning
- qudit-sim — Multi-qudit system simulation and analysis
- ramsey — Probabilistic deep learning using JAX
- redco — no summary
- reinforced-lib — Reinforcement learning library
- sbijax — Simulation-based inference in JAX
- sbx-rl — Jax version of Stable Baselines, implementations of reinforcement learning algorithms.
- scENVI — Integration of scRNA-seq and spatial transcriptomics data
- scKinetics — Biological prior guided single-cell kinetics inference.
- scvi-tools — Deep probabilistic analysis of single-cell omics data.
- segnn-jax — Steerable E(3) GNN in jax
- sentinex — Sentinex: A high level interface aimed towards rapid prototyping and intuitive workflow for JAX.
- sequel-core — A Continual Learning Framework for both Jax and PyTorch.
- sgGWR — no summary
- SGMCMCJax — SGMCMC samplers in JAX
- sibylla — A Jax package for Gradient Descent Image Reconstruction
- skrl — Modular and flexible library for reinforcement learning on PyTorch and JAX
- slax — An SNN/RNN training package in JAX
- snaxlib — A simple deep learning library for JAX.
- sparsenn — no summary
- spyx — Spyx: SNNs in JAX
- stadion — Causal Modeling with Stationary Diffusions
- surjectors — Surjection layers for density estimation with normalizing flows
- teneva-ht-jax — Compact implementation of basic operations in the Hierarchical Tucker (HT) format for approximation and sampling from multidimensional arrays and multivariate functions
- tensorwrap — TensorWrap: A high level TensorFlow wrapper for JAX.
- tfx — TensorFlow Extended (TFX) is a TensorFlow-based general-purpose machine learning platform implemented at Google.
- tidy3d — A fast FDTD solver
- tinygp — The tiniest of Gaussian Process libraries
- tjax — Tools for JAX.
- TorchOpt — An efficient library for differentiable optimization for PyTorch.
- totypes — Custom datatypes useful in a topology optimization context
- transformers — State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
- treex — no summary