@@ -3861,58 +3861,114 @@ bool AMDGPUDAGToDAGISel::SelectVOP3OpSelMods(SDValue In, SDValue &Src,
3861
3861
return SelectVOP3Mods (In, Src, SrcMods);
3862
3862
}
3863
3863
3864
+ // Match lowered fpext from bf16 to f32. This is a bit operation extending
3865
+ // a 16-bit value with 16-bit of zeroes at LSB:
3866
+ //
3867
+ // 1. (f32 (bitcast (build_vector (i16 0), (i16 (bitcast bf16:val)))))
3868
+ // 2. (f32 (bitcast (and i32:val, 0xffff0000))) -> IsExtractHigh = true
3869
+ // 3. (f32 (bitcast (shl i32:va, 16) -> IsExtractHigh = false
3870
+ static SDValue matchBF16FPExtendLike (SDValue Op, bool &IsExtractHigh) {
3871
+ if (Op.getValueType () != MVT::f32 || Op.getOpcode () != ISD::BITCAST)
3872
+ return SDValue ();
3873
+ Op = Op.getOperand (0 );
3874
+
3875
+ IsExtractHigh = false ;
3876
+ if (Op.getValueType () == MVT::v2i16 && Op.getOpcode () == ISD::BUILD_VECTOR) {
3877
+ auto Low16 = dyn_cast<ConstantSDNode>(Op.getOperand (0 ));
3878
+ if (!Low16 || !Low16->isZero ())
3879
+ return SDValue ();
3880
+ Op = stripBitcast (Op.getOperand (1 ));
3881
+ if (Op.getValueType () != MVT::bf16 )
3882
+ return SDValue ();
3883
+ return Op;
3884
+ }
3885
+
3886
+ if (Op.getValueType () != MVT::i32 )
3887
+ return SDValue ();
3888
+
3889
+ if (Op.getOpcode () == ISD::AND) {
3890
+ if (auto Mask = dyn_cast<ConstantSDNode>(Op.getOperand (1 ))) {
3891
+ if (Mask->getZExtValue () == 0xffff0000 ) {
3892
+ IsExtractHigh = true ;
3893
+ return Op.getOperand (0 );
3894
+ }
3895
+ }
3896
+ return SDValue ();
3897
+ }
3898
+
3899
+ if (Op.getOpcode () == ISD::SHL) {
3900
+ if (auto Amt = dyn_cast<ConstantSDNode>(Op.getOperand (1 ))) {
3901
+ if (Amt->getZExtValue () == 16 )
3902
+ return Op.getOperand (0 );
3903
+ }
3904
+ }
3905
+
3906
+ return SDValue ();
3907
+ }
3908
+
3864
3909
// The return value is not whether the match is possible (which it always is),
3865
3910
// but whether or not it a conversion is really used.
3866
3911
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsImpl (SDValue In, SDValue &Src,
3867
- unsigned &Mods) const {
3912
+ unsigned &Mods,
3913
+ MVT VT) const {
3868
3914
Mods = 0 ;
3869
3915
SelectVOP3ModsImpl (In, Src, Mods);
3870
3916
3917
+ bool IsExtractHigh = false ;
3871
3918
if (Src.getOpcode () == ISD::FP_EXTEND) {
3872
3919
Src = Src.getOperand (0 );
3873
- assert (Src.getValueType () == MVT::f16 );
3874
- Src = stripBitcast (Src);
3920
+ } else if (VT == MVT::bf16 ) {
3921
+ SDValue B16 = matchBF16FPExtendLike (Src, IsExtractHigh);
3922
+ if (!B16)
3923
+ return false ;
3924
+ Src = B16;
3925
+ } else
3926
+ return false ;
3875
3927
3876
- // Be careful about folding modifiers if we already have an abs. fneg is
3877
- // applied last, so we don't want to apply an earlier fneg.
3878
- if ((Mods & SISrcMods::ABS) == 0 ) {
3879
- unsigned ModsTmp;
3880
- SelectVOP3ModsImpl (Src, Src, ModsTmp);
3928
+ if (Src.getValueType () != VT &&
3929
+ (VT != MVT::bf16 || Src.getValueType () != MVT::i32 ))
3930
+ return false ;
3881
3931
3882
- if ((ModsTmp & SISrcMods::NEG) != 0 )
3883
- Mods ^= SISrcMods::NEG;
3932
+ Src = stripBitcast (Src);
3884
3933
3885
- if ((ModsTmp & SISrcMods::ABS) != 0 )
3886
- Mods |= SISrcMods::ABS;
3887
- }
3934
+ // Be careful about folding modifiers if we already have an abs. fneg is
3935
+ // applied last, so we don't want to apply an earlier fneg.
3936
+ if ((Mods & SISrcMods::ABS) == 0 ) {
3937
+ unsigned ModsTmp;
3938
+ SelectVOP3ModsImpl (Src, Src, ModsTmp);
3939
+
3940
+ if ((ModsTmp & SISrcMods::NEG) != 0 )
3941
+ Mods ^= SISrcMods::NEG;
3888
3942
3889
- // op_sel/op_sel_hi decide the source type and source.
3890
- // If the source's op_sel_hi is set, it indicates to do a conversion from fp16.
3891
- // If the sources's op_sel is set, it picks the high half of the source
3892
- // register.
3943
+ if ((ModsTmp & SISrcMods::ABS) != 0 )
3944
+ Mods |= SISrcMods::ABS;
3945
+ }
3893
3946
3894
- Mods |= SISrcMods::OP_SEL_1;
3895
- if (isExtractHiElt (Src, Src)) {
3896
- Mods |= SISrcMods::OP_SEL_0;
3947
+ // op_sel/op_sel_hi decide the source type and source.
3948
+ // If the source's op_sel_hi is set, it indicates to do a conversion from
3949
+ // fp16. If the sources's op_sel is set, it picks the high half of the source
3950
+ // register.
3897
3951
3898
- // TODO: Should we try to look for neg/abs here?
3899
- }
3952
+ Mods |= SISrcMods::OP_SEL_1;
3953
+ if (IsExtractHigh ||
3954
+ (Src.getValueSizeInBits () == 16 && isExtractHiElt (Src, Src))) {
3955
+ Mods |= SISrcMods::OP_SEL_0;
3900
3956
3901
- // Prevent unnecessary subreg COPY to VGPR_16
3902
- if (Src.getOpcode () == ISD::TRUNCATE &&
3903
- Src.getOperand (0 ).getValueType () == MVT::i32 ) {
3904
- Src = Src.getOperand (0 );
3905
- }
3906
- return true ;
3957
+ // TODO: Should we try to look for neg/abs here?
3907
3958
}
3908
3959
3909
- return false ;
3960
+ // Prevent unnecessary subreg COPY to VGPR_16
3961
+ if (Src.getOpcode () == ISD::TRUNCATE &&
3962
+ Src.getOperand (0 ).getValueType () == MVT::i32 ) {
3963
+ Src = Src.getOperand (0 );
3964
+ }
3965
+ return true ;
3910
3966
}
3911
3967
3912
3968
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt (SDValue In, SDValue &Src,
3913
3969
SDValue &SrcMods) const {
3914
3970
unsigned Mods = 0 ;
3915
- if (!SelectVOP3PMadMixModsImpl (In, Src, Mods))
3971
+ if (!SelectVOP3PMadMixModsImpl (In, Src, Mods, MVT:: f16 ))
3916
3972
return false ;
3917
3973
SrcMods = CurDAG->getTargetConstant (Mods, SDLoc (In), MVT::i32 );
3918
3974
return true ;
@@ -3921,7 +3977,24 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
3921
3977
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixMods (SDValue In, SDValue &Src,
3922
3978
SDValue &SrcMods) const {
3923
3979
unsigned Mods = 0 ;
3924
- SelectVOP3PMadMixModsImpl (In, Src, Mods);
3980
+ SelectVOP3PMadMixModsImpl (In, Src, Mods, MVT::f16 );
3981
+ SrcMods = CurDAG->getTargetConstant (Mods, SDLoc (In), MVT::i32 );
3982
+ return true ;
3983
+ }
3984
+
3985
+ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16ModsExt (SDValue In, SDValue &Src,
3986
+ SDValue &SrcMods) const {
3987
+ unsigned Mods = 0 ;
3988
+ if (!SelectVOP3PMadMixModsImpl (In, Src, Mods, MVT::bf16 ))
3989
+ return false ;
3990
+ SrcMods = CurDAG->getTargetConstant (Mods, SDLoc (In), MVT::i32 );
3991
+ return true ;
3992
+ }
3993
+
3994
+ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16Mods (SDValue In, SDValue &Src,
3995
+ SDValue &SrcMods) const {
3996
+ unsigned Mods = 0 ;
3997
+ SelectVOP3PMadMixModsImpl (In, Src, Mods, MVT::bf16 );
3925
3998
SrcMods = CurDAG->getTargetConstant (Mods, SDLoc (In), MVT::i32 );
3926
3999
return true ;
3927
4000
}
0 commit comments