Skip to content

Jax

Jax提供自动微分和 JIT 编译功能,不过@jit只能用于静态的func!否则参考‘有状态计算

jnp操作类似np, 而PyTree类似dict,具体区别包括 数组更新不可in_place更新 等

与TF基本通用

安装

pip install -U jax
pip install -U flax

一些基本操作:Basics_Jax.ipynb

如何进行NN训练:Basics_Jax_NN.ipynb

参考

官方教程 https://jax.ac.cn/en/latest/

安装报错1 https://blog.csdn.net/duoyasong5907/article/details/142324484

纯Python函数的回调!!  https://jax.ac.cn/en/latest/_autosummary/jax.pure_callback.html

教程  https://www.bilibili.com/video/BV1TppqeAEPN/?spm_id_from=333.1007.0.0