reinforcement learning - How to use jax.vmap with a tuple of flax TrainStates as input? - Stack Overflow

admin2025-04-17  5

I am setting up a Deep MARL framework and I need to assess my actor policies. Ideally, this would entail using jax.vmap over a tuple of actor flax TrainStates. I have tried the following:

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
import optax
import distrax

class PGActor_1(nn.Module):

    @nn.compact
    def __call__(self, x):
        action_dim = 4
        activation = nn.tanh

        # Two hidden layers (128 and 64 units) followed by the logits layer.
        actor_mean = nn.Dense(128, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0))(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0))(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
        # Categorical policy over the action_dim discrete actions.
        pi = distrax.Categorical(logits=actor_mean)

        return pi

class PGActor_2(nn.Module):

    @nn.compact
    def __call__(self, x):
        action_dim = 2
        activation = nn.tanh

        # One hidden layer (64 units) applied to the observation, then the logits layer.
        actor_mean = nn.Dense(64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0))(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
        # Categorical policy over the action_dim discrete actions.
        pi = distrax.Categorical(logits=actor_mean)

        return pi

state = jnp.zeros((1, 5))

network_1 = PGActor_1()
network_1_init_rng = jax.random.PRNGKey(42)
params_1 = network_1.init(network_1_init_rng, state)

network_2 = PGActor_2()
network_2_init_rng = jax.random.PRNGKey(42)
params_2 = network_2.init(network_2_init_rng, state)

tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)
actor_trainstates = (
    TrainState.create(apply_fn=network_1.apply, tx=tx, params=params_1),
    TrainState.create(apply_fn=network_2.apply, tx=tx, params=params_2),
)
pis = jax.vmap(lambda x: x.apply_fn(x.params, state))(actor_trainstates)

but I receive the following error:

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

Does anybody have any idea how to make this work?

Thank you in advance.

asked Feb 1 at 13:23 by amavrits; edited Feb 4 at 9:56
  • The answer will depend on the details of how actor_trainstates is defined – it would be helpful if you could edit your question to add a minimal reproducible example to take out the guesswork. – jakevdp Commented Feb 1 at 18:41
  • Thanks for your suggestion. I have edited in some specifics of my application. The "network" is simply an MLP flax nn.Module. – amavrits Commented Feb 1 at 19:17
  • Jax only supports "structs of arrays" aka PyTrees. You cannot "vmap" over members of a PyTree such as your tuple, which would represent an "array of structs". The solution here is likely to move the abstraction of the multiple model states into a new "batch like" axis of params. So "vmap" over the model instead and introduce a new axis to params, which represents the different model states. I hope this makes sense! – Axel Donath Commented Feb 3 at 16:04
  • @AxelDonath is correct – you cannot vmap over tuples or Python sequences, only over arrays. It would be easier to give a full answer to your question if your code snippet defined all relevant variables. – jakevdp Commented Feb 3 at 21:20
  • @jakevdp Thank you for your reply. I have made some additions to the snippet, which hopefully clear up the situation. Please let me know if this is sufficient. – amavrits Commented Feb 4 at 9:42

1 Answer

This is quite similar to other questions (e.g. Jax - vmap over batch of dataclasses). The key point is that JAX transformations like vmap require data in a struct of arrays pattern, whereas you are using an array of structs pattern.
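As a rough illustration of the difference between the two patterns, here is a minimal sketch in plain JAX. The records, the "w" and "step" field names, and the weight_sum function are made up for this example; "step" stands in for a scalar counter such as the one a TrainState carries.

```python
import jax
import jax.numpy as jnp

# Array of structs: a Python tuple holding one record per actor.
# The "w" and "step" fields are hypothetical stand-ins for params and a step counter.
aos = (
    {"w": jnp.ones((3,)), "step": jnp.array(0)},
    {"w": 2.0 * jnp.ones((3,)), "step": jnp.array(0)},
)

# Struct of arrays: a single record whose leaves carry a leading actor axis.
soa = {
    "w": jnp.stack([actor["w"] for actor in aos]),        # shape (2, 3)
    "step": jnp.stack([actor["step"] for actor in aos]),  # shape (2,)
}

def weight_sum(actor):
    return actor["w"].sum()

# vmap slices every leaf along axis 0, which works for the struct of arrays.
sums = jax.vmap(weight_sum)(soa)

# On the tuple, vmap also tries to slice every leaf along axis 0,
# and the rank-0 "step" leaves have no such axis.
try:
    jax.vmap(weight_sum)(aos)
    aos_error = None
except ValueError as err:
    aos_error = str(err)
```

Here sums has one entry per actor, while aos_error holds a rank complaint of the same kind as the one in the question.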

To work directly with an array of structs pattern in JAX, you can use Python's built-in map function – due to JAX's asynchronous dispatch, the resulting operations will be executed in parallel where possible:

pis = list(map(lambda x: x.apply_fn(x.params, state), actor_trainstates))  # list() forces the lazy map iterator to run
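A sketch of the same per-actor loop using plain JAX stand-ins for the Flax modules, so that it runs without flax, optax, or distrax. The init_mlp and mlp_logits helpers are hypothetical, and the layer sizes only mirror the two actors in the question. Because the two parameter trees differ in structure, jax.jit compiles the forward pass once per actor.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Hypothetical helper: one (weight, bias) pair per consecutive pair of layer sizes."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (0.1 * jax.random.normal(k, (n_in, n_out)), jnp.zeros((n_out,)))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]

@jax.jit
def mlp_logits(params, x):
    # tanh on the hidden layers, no activation on the final (logits) layer.
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

state = jnp.zeros((1, 5))

# Two actors with different depths and action counts, as in the question.
actor_params = (
    init_mlp(jax.random.PRNGKey(0), [5, 128, 64, 4]),
    init_mlp(jax.random.PRNGKey(1), [5, 64, 2]),
)

# A plain Python loop over the actors; each call is dispatched separately.
logits = [mlp_logits(params, state) for params in actor_params]
```

The result is a Python list with one logits array per actor, and the arrays are free to have different shapes.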

However, this doesn't take advantage of the automatic vectorization done by vmap. In order to do this, you can convert your data from an array of structs to a struct of arrays, although this requires that all entries have the same structure.

For compatible cases, the solution would look something like the following; however, it raises an error for your data:

train_states_soa = jax.tree.map(lambda *args: jnp.stack(args), *actor_trainstates)
pis = jax.vmap(lambda x: x.apply_fn(x.params, state))(train_states_soa)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-36-da904fa40b9c> in <cell line: 0>()
----> 1 train_states_soa = jax.tree.map(lambda *args: jnp.stack(args), *actor_trainstates)

ValueError: Dict key mismatch; expected keys: ['Dense_0', 'Dense_1', 'Dense_2']

The problem is that your two train states do not have matching structure, and so they cannot be transformed into a single struct of arrays. You can see the difference in structure by inspecting the params:

print(actor_trainstates[0].params['params'].keys())  # dict_keys(['Dense_0', 'Dense_1', 'Dense_2'])
print(actor_trainstates[1].params['params'].keys())  # dict_keys(['Dense_0', 'Dense_1'])
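The same check can be done programmatically before attempting to stack. The sketch below uses hand-written parameter dicts whose keys and shapes are assumed to mirror the two actors; can_stack is a hypothetical helper, not part of JAX.

```python
import jax
import jax.numpy as jnp

# Hand-written stand-ins for the two actors' parameter trees (assumed shapes).
params_a = {
    "Dense_0": {"kernel": jnp.zeros((5, 128)), "bias": jnp.zeros((128,))},
    "Dense_1": {"kernel": jnp.zeros((128, 64)), "bias": jnp.zeros((64,))},
    "Dense_2": {"kernel": jnp.zeros((64, 4)), "bias": jnp.zeros((4,))},
}
params_b = {
    "Dense_0": {"kernel": jnp.zeros((5, 64)), "bias": jnp.zeros((64,))},
    "Dense_1": {"kernel": jnp.zeros((64, 2)), "bias": jnp.zeros((2,))},
}

def can_stack(*trees):
    """Hypothetical helper: True if all trees share a tree structure and leaf shapes."""
    structures = [jax.tree.structure(tree) for tree in trees]
    if any(s != structures[0] for s in structures):
        return False
    shapes = [[leaf.shape for leaf in jax.tree.leaves(tree)] for tree in trees]
    return all(s == shapes[0] for s in shapes)

mismatched = can_stack(params_a, params_b)  # different keys, so not stackable
matched = can_stack(params_a, params_a)     # identical trees are stackable
```

Stacking also needs matching leaf shapes, which is why the helper compares shapes as well as tree structure.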

There is no way to use vmap in a context where your inputs have different structure, so you'll either have to change the problem to ensure the same structure, or stick with the map approach.
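One possible way to give every actor the same structure is sketched below, under strong assumptions that are not part of the original question: each actor uses the same layer sizes, each output is padded to the largest action count, and the unused actions are masked out. The init_mlp and mlp_logits helpers and the MAX_ACTIONS constant are hypothetical.

```python
import jax
import jax.numpy as jnp

MAX_ACTIONS = 4  # assumption: pad every actor's output to the largest action space

def init_mlp(key, sizes):
    """Hypothetical helper: one (weight, bias) pair per consecutive pair of layer sizes."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (0.1 * jax.random.normal(k, (n_in, n_out)), jnp.zeros((n_out,)))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp_logits(params, action_mask, x):
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    logits = x @ w + b
    # Invalid actions get a very negative logit, so their probability is ~0.
    return jnp.where(action_mask, logits, -1e9)

# Assumption: both actors share one architecture, so their trees match.
sizes = [5, 64, 64, MAX_ACTIONS]
per_actor = [init_mlp(jax.random.PRNGKey(i), sizes) for i in range(2)]

# Array of structs -> struct of arrays: every leaf gains a leading actor axis.
stacked = jax.tree.map(lambda *leaves: jnp.stack(leaves), *per_actor)

# Actor 0 has 4 valid actions, actor 1 only the first 2.
masks = jnp.array([[True, True, True, True],
                   [True, True, False, False]])
state = jnp.zeros((1, 5))

# Map over the actor axis of params and masks; share the same state.
logits = jax.vmap(mlp_logits, in_axes=(0, 0, None))(stacked, masks, state)
probs = jax.nn.softmax(logits, axis=-1)
```

This trades some wasted computation on the padded outputs for a single vectorized call; whether that is worthwhile depends on how different the actors really need to be.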

When reposting, please cite the original URL: http://anycun.com/QandA/1744827616a88173.html