
Conversation

@sararb (Contributor) commented Jun 27, 2023

Fixes # (issue)

Goals ⚽

  • Add support for padding, transforming, and masking sequential input data in the MM PyTorch backend
  • The implemented transform classes should:
    - Support multiple targets
    - Be usable for training, evaluation, and inference
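As context for the padding goal, here is a minimal sketch of right-padding a batch of variable-length sequential features to a common length; the helper name and the use of plain tensors are illustrative assumptions, not the Merlin Models API:

```python
import torch
import torch.nn.functional as F

def pad_to_length(seq: torch.Tensor, max_length: int, pad_value: int = 0) -> torch.Tensor:
    """Right-pad (or truncate) a 1-D sequence tensor to exactly max_length."""
    seq = seq[:max_length]
    return F.pad(seq, (0, max_length - seq.size(0)), value=pad_value)

# Two variable-length item-id sequences padded into one dense batch.
batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
padded = torch.stack([pad_to_length(s, max_length=4) for s in batch])
# padded is [[1, 2, 3, 0], [4, 5, 0, 0]]
```

The padded batch can then be fed to any module that expects dense tensors, with the pad value (here 0) marking positions to ignore.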

Implementation Details 🚧

  • Implement TabularBatchPadding to pad a group of sequential inputs
  • Implement TabularPredictNext for generating targets of causal next item prediction
  • Implement TabularPredictLast for generating targets of last item prediction
  • Implement TabularPredictRandom for generating targets of predicting one random item, truncating the sequence so that the random item is in the last position
  • Implement TabularMaskRandom for the masked language modeling (MLM) training strategy
  • Implement TabularMaskLast for masking the last item in the sequence, generally used to evaluate models trained with MLM
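To illustrate the intent of the target-generating transforms above, here is a hedged sketch of the underlying tensor operations on a batch of item-id sequences. The function names are mine, not the PR's; the real classes operate on tabular batches and support multiple targets:

```python
import torch

def predict_next(ids: torch.Tensor):
    """Causal next-item prediction: each position's target is the following item."""
    return ids[:, :-1], ids[:, 1:]

def predict_last(ids: torch.Tensor):
    """Last-item prediction: the final item of each sequence is the target."""
    return ids[:, :-1], ids[:, -1]

def mask_random(ids: torch.Tensor, prob: float = 0.2):
    """MLM-style masking: a boolean mask marks randomly chosen target positions."""
    mask = torch.rand(ids.shape) < prob
    return ids, mask

ids = torch.tensor([[10, 11, 12, 13]])
inputs, targets = predict_next(ids)
# inputs is [[10, 11, 12]], targets is [[11, 12, 13]]
```

The key design point is that these are batch-level transforms: they can run inside the training loop, so the same padded batch can yield causal targets during training and last-item targets during evaluation.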

Testing Details 🔍

  • Defined tests for padding and the different sequence transformations
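The actual tests live in the PR; as a flavor of what such a test might check (the test name and shapes are hypothetical), a consistency check for the next-item transform could look like:

```python
import torch

def test_next_item_targets_are_shifted_inputs():
    ids = torch.arange(12).reshape(3, 4)
    inputs, targets = ids[:, :-1], ids[:, 1:]
    assert inputs.shape == targets.shape == (3, 3)
    # Every target equals the input one step later in the original sequence.
    assert torch.equal(inputs[:, 1:], targets[:, :-1])
```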

@github-actions

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1161

MASK_PREFIX = "__mask"


class TabularBatchPadding(nn.Module):
Contributor


Should we rename to TabularPadding?

@sararb changed the title from "first version of sequence transforms applied to Batch input" to "MM Pytorch API: TabularTransform for input tabular sequence" on Jun 28, 2023
MASK_PREFIX = "__mask"


class TabularPadding(nn.Module):
Contributor


I think it might be better to break padding out into a smaller PR first, and then do masking afterwards.



3 participants