Dispatching parallel data input dynamically is an important way to handle embedding model parallelism in recommender models. Please tell me how to support this feature.
Here is a snapshot from the Meta paper:
This discussion was converted from issue #22296 on July 08, 2024 11:10.
For example
input = [0,1,2,3,4,5,6,7]
The input is regarded as one training sample and is segmented along the batch dimension with shard_map; 'i' is the mesh axis name.
device 0:
input = [0,1,2,3]
device 1:
input = [4,5,6,7]
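
(A minimal sketch of this split, assuming a 2-device mesh with axis name 'i'; the mesh construction and the helper name `local_view` are illustrative assumptions, not part of the original example:)

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Assumes at least 2 devices; 'i' is the mesh axis used for the batch split.
mesh = Mesh(np.array(jax.devices()[:2]), axis_names=('i',))

def local_view(x):
    # Inside shard_map each device sees only its own shard:
    # device 0 -> [0, 1, 2, 3], device 1 -> [4, 5, 6, 7]
    return x

f = shard_map(local_view, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
out = f(jnp.arange(8))
```

Calling `f(jnp.arange(8))` returns the full array, but inside `local_view` each device only ever sees its 4-element shard.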
And I need to use lax.ppermute to exchange the input between different devices:
device 0:
lax.ppermute(input[0:1], 'i', [(0,1),(1,0)])
device 1:
lax.ppermute(input[1:3], 'i', [(0,1),(1,0)])
Note that jax.disable_jit() does not help here.
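
(For contrast, a hedged sketch of the case that does work under jit: ppermute applied to equal-shaped per-device shards. The mesh setup and the helper name `exchange` are my own assumptions, not from the post. The device-dependent slice bounds above, `input[0:1]` on device 0 versus `input[1:3]` on device 1, are the part that fails, because every device traces the same program under jit and ppermute expects identical operand shapes on all participants:)

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), axis_names=('i',))

def exchange(x):
    # x is the local shard: [0, 1, 2, 3] on device 0, [4, 5, 6, 7] on device 1.
    # Swap the full shards between the two devices; both operands have the
    # same shape, which is what ppermute expects.
    return lax.ppermute(x, axis_name='i', perm=[(0, 1), (1, 0)])

f = jax.jit(shard_map(exchange, mesh=mesh, in_specs=P('i'), out_specs=P('i')))
out = f(jnp.arange(8))  # device 0 now holds [4, 5, 6, 7], device 1 holds [0, 1, 2, 3]
```

Supporting genuinely unequal per-device shards would presumably require padding to a common shape or a ragged/all-to-all style collective, which seems to be what this feature request is asking about.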