Skip to content

Commit

Permalink
BoringMC: Change the SOT to be JAX rather than TensorFlow.
Browse files Browse the repository at this point in the history
This relies heavily on TensorFlow's experimental numpy API, so `tf.experimental.numpy.experimental_enable_numpy_behavior()` must be called before using FunMC from TF.

PiperOrigin-RevId: 599007999
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Jan 17, 2024
1 parent 84d0c6a commit 4c2e693
Show file tree
Hide file tree
Showing 20 changed files with 2,647 additions and 2,121 deletions.
9 changes: 6 additions & 3 deletions spinoffs/fun_mc/fun_mc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# Description:
# Functional MC API.

# Placeholder: py_test
# [internal] load pytype.bzl (pytype_library)
# Placeholder: py_test

licenses(["notice"])

Expand Down Expand Up @@ -55,7 +55,7 @@ py_library(
name = "backend",
srcs = ["backend.py"],
deps = [
"//fun_mc/dynamic/backend_tensorflow:backend",
"//fun_mc/dynamic/backend_jax:backend",
],
)

Expand Down Expand Up @@ -108,6 +108,7 @@ py_test(
shard_count = 8,
deps = [
":fun_mc",
":prefab",
":test_util",
# absl/testing:parameterized dep,
# scipy dep,
Expand All @@ -134,6 +135,7 @@ py_test(
shard_count = 2,
deps = [
":fun_mc",
":malt",
":test_util",
# jax dep,
# tensorflow dep,
Expand All @@ -159,6 +161,7 @@ py_test(
shard_count = 2,
deps = [
":fun_mc",
":sga_hmc",
":test_util",
# jax dep,
# tensorflow dep,
Expand All @@ -185,6 +188,7 @@ py_test(
shard_count = 2,
deps = [
":fun_mc",
":prefab",
":test_util",
# jax dep,
# tensorflow dep,
Expand All @@ -200,7 +204,6 @@ py_library(
deps = [
":backend",
":fun_mc_lib",
# tensorflow_probability dep,
],
)

Expand Down
2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# ============================================================================
"""Default backend implementation."""

from fun_mc.dynamic.backend_tensorflow.backend import * # pylint: disable=wildcard-import
from fun_mc.dynamic.backend_jax.backend import * # pylint: disable=wildcard-import
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from fun_mc import using_tensorflow as fun_mc
from absl.testing import absltest

tf.experimental.numpy.experimental_enable_numpy_behavior()


class TensorFlowIntegrationTest(absltest.TestCase):

Expand Down
10 changes: 0 additions & 10 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,10 @@ py_library(
],
)

py_library(
name = "tf_on_jax",
srcs = ["tf_on_jax.py"],
deps = [
# jax dep,
# jax:stax dep,
],
)

py_library(
name = "backend",
srcs = ["backend.py"],
deps = [
":tf_on_jax",
":util",
# tensorflow_probability/substrates:jax dep,
],
Expand Down
9 changes: 5 additions & 4 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
# ============================================================================
"""JAX backend."""

from fun_mc.dynamic.backend_jax import tf_on_jax
import jax
import jax.numpy as jnp

from fun_mc.dynamic.backend_jax import util
from tensorflow_probability.substrates import jax as tfp
from tensorflow_probability.substrates.jax.internal import distribute_lib
from tensorflow_probability.substrates.jax.internal import prefer_static

tf = tf_on_jax.tf

__all__ = [
'distribute_lib',
'prefer_static',
'tf',
'jax',
'jnp',
'tfp',
'util',
]
208 changes: 0 additions & 208 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/tf_on_jax.py

This file was deleted.

42 changes: 41 additions & 1 deletion spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,56 @@
# ============================================================================
"""TensorFlow backend."""

import types
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import distribute_lib
from tensorflow_probability.python.internal import prefer_static
from fun_mc.dynamic.backend_tensorflow import util

tnp = tf.experimental.numpy

_lax = types.ModuleType('lax')
_lax.cond = tf.cond
_lax.stop_gradient = tf.stop_gradient

_nn = types.ModuleType('nn')
_nn.softmax = tf.nn.softmax
_nn.one_hot = tf.one_hot


class _ShapeDtypeStruct:
pass


jax = types.ModuleType('jax')
jax.ShapeDtypeStruct = _ShapeDtypeStruct
jax.jit = tf.function
jax.lax = _lax
jax.custom_gradient = tf.custom_gradient
jax.nn = _nn


class _JNP(types.ModuleType):

def __getattr__(self, name):
return getattr(tnp, name)


jnp = _JNP('numpy')
jnp.dtype = tf.DType
# These are technically provided by TensorFlow, but only after numpy mode is
# enabled.
jnp.ndarray = tf.Tensor
jnp.float32 = tf.float32
jnp.float64 = tf.float64


__all__ = [
'distribute_lib',
'prefer_static',
'tf',
'jnp',
'jax',
'tfp',
'util',
]
Loading

0 comments on commit 4c2e693

Please sign in to comment.