diff --git a/hibernate-core/src/main/java/org/hibernate/boot/jaxb/mapping/spi/JaxbPluralAttribute.java b/hibernate-core/src/main/java/org/hibernate/boot/jaxb/mapping/spi/JaxbPluralAttribute.java
index 6e51a13fe008..c52511c0aa01 100644
--- a/hibernate-core/src/main/java/org/hibernate/boot/jaxb/mapping/spi/JaxbPluralAttribute.java
+++ b/hibernate-core/src/main/java/org/hibernate/boot/jaxb/mapping/spi/JaxbPluralAttribute.java
@@ -26,6 +26,8 @@ public interface JaxbPluralAttribute extends JaxbPersistentAttribute, JaxbLockab
 	JaxbCollectionIdImpl getCollectionId();
 	void setCollectionId(JaxbCollectionIdImpl id);
 
+	Integer getBatchSize();
+	void setBatchSize(Integer size);
 
 	LimitedCollectionClassification getClassification();
 	void setClassification(LimitedCollectionClassification value);
diff --git a/hibernate-core/src/main/java/org/hibernate/boot/models/xml/internal/attr/CommonPluralAttributeProcessing.java b/hibernate-core/src/main/java/org/hibernate/boot/models/xml/internal/attr/CommonPluralAttributeProcessing.java
index e4746da126ef..b77deb2305a8 100644
--- a/hibernate-core/src/main/java/org/hibernate/boot/models/xml/internal/attr/CommonPluralAttributeProcessing.java
+++ b/hibernate-core/src/main/java/org/hibernate/boot/models/xml/internal/attr/CommonPluralAttributeProcessing.java
@@ -11,6 +11,7 @@
 import org.hibernate.boot.jaxb.mapping.spi.JaxbPluralFetchModeImpl;
 import org.hibernate.boot.models.HibernateAnnotations;
 import org.hibernate.boot.models.JpaAnnotations;
+import org.hibernate.boot.models.annotations.internal.BatchSizeAnnotation;
 import org.hibernate.boot.models.annotations.internal.FetchAnnotation;
 import org.hibernate.boot.models.annotations.internal.MapKeyClassJpaAnnotation;
 import org.hibernate.boot.models.annotations.internal.MapKeyColumnJpaAnnotation;
@@ -62,6 +63,14 @@ public static void applyPluralAttributeStructure(
 			}
 		}
 
+		if ( jaxbPluralAttribute.getBatchSize() != null ) {
+			final BatchSizeAnnotation batchSizeAnnotation = (BatchSizeAnnotation) memberDetails.applyAnnotationUsage(
+					HibernateAnnotations.BATCH_SIZE,
+					buildingContext
+			);
+			batchSizeAnnotation.size( jaxbPluralAttribute.getBatchSize() );
+		}
+
 		// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 		// collection-structure
 
diff --git a/hibernate-core/src/main/resources/org/hibernate/xsd/mapping/mapping-7.0.xsd b/hibernate-core/src/main/resources/org/hibernate/xsd/mapping/mapping-7.0.xsd
index 1118b61457f8..dd5129f41c72 100644
--- a/hibernate-core/src/main/resources/org/hibernate/xsd/mapping/mapping-7.0.xsd
+++ b/hibernate-core/src/main/resources/org/hibernate/xsd/mapping/mapping-7.0.xsd
@@ -3313,6 +3313,9 @@
                 
             
 
+            
+            
+
             
 
             
diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/BatchFetchTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/BatchFetchTest.java
index 8b33fa53c275..026c6a47e57c 100644
--- a/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/BatchFetchTest.java
+++ b/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/BatchFetchTest.java
@@ -4,13 +4,7 @@
  */
 package org.hibernate.orm.test.batchfetch;
 
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-
 import org.hibernate.Hibernate;
-import org.hibernate.cfg.AvailableSettings;
-
 import org.hibernate.testing.orm.junit.DomainModel;
 import org.hibernate.testing.orm.junit.ServiceRegistry;
 import org.hibernate.testing.orm.junit.SessionFactory;
@@ -19,6 +13,12 @@
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hibernate.cfg.CacheSettings.USE_SECOND_LEVEL_CACHE;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -27,166 +27,118 @@
 /**
  * @author Gavin King
  */
+@SuppressWarnings("JUnitMalformedDeclaration")
 @DomainModel(
-		xmlMappings = "org/hibernate/orm/test/batchfetch/ProductLine.hbm.xml",
+		xmlMappings = "org/hibernate/orm/test/batchfetch/ProductLine.xml",
 		annotatedClasses = BatchLoadableEntity.class
 )
-@SessionFactory(
-		generateStatistics = true
-)
-@ServiceRegistry(
-		settings = {
-				@Setting(name = AvailableSettings.USE_SECOND_LEVEL_CACHE, value = "false")
-		}
-)
+@SessionFactory(generateStatistics = true)
+@ServiceRegistry(settings = @Setting(name = USE_SECOND_LEVEL_CACHE, value = "false"))
 public class BatchFetchTest {
 
 	@SuppressWarnings("unchecked")
 	@Test
 	public void testBatchFetch(SessionFactoryScope scope) {
-		ProductLine ossProductLine = new ProductLine();
-		Model hibernateModel = new Model( ossProductLine );
-		scope.inTransaction(
-				session -> {
-					ProductLine cars = new ProductLine();
-					cars.setDescription( "Cars" );
-					Model monaro = new Model( cars );
-					monaro.setName( "monaro" );
-					monaro.setDescription( "Holden Monaro" );
-					Model hsv = new Model( cars );
-					hsv.setName( "hsv" );
-					hsv.setDescription( "Holden Commodore HSV" );
-					session.persist( cars );
-
-					ossProductLine.setDescription( "OSS" );
-					Model jboss = new Model( ossProductLine );
-					jboss.setName( "JBoss" );
-					jboss.setDescription( "JBoss Application Server" );
-
-					hibernateModel.setName( "Hibernate" );
-					hibernateModel.setDescription( "Hibernate" );
-					Model cache = new Model( ossProductLine );
-					cache.setName( "JBossCache" );
-					cache.setDescription( "JBoss TreeCache" );
-					session.persist( ossProductLine );
-				}
-		);
-
-		scope.getSessionFactory().getCache().evictEntityData( Model.class );
-		scope.getSessionFactory().getCache().evictEntityData( ProductLine.class );
-
-		scope.inTransaction(
-				session -> {
-					List list = session.createQuery( "from ProductLine pl order by pl.description" )
-							.list();
-					ProductLine cars = list.get( 0 );
-					ProductLine oss = list.get( 1 );
-					assertFalse( Hibernate.isInitialized( cars.getModels() ) );
-					assertFalse( Hibernate.isInitialized( oss.getModels() ) );
-					assertEquals( 2, cars.getModels().size() ); //fetch both collections
-					assertTrue( Hibernate.isInitialized( cars.getModels() ) );
-					assertTrue( Hibernate.isInitialized( oss.getModels() ) );
-
-					session.clear();
-
-					List models = session.createQuery( "from Model m" ).list();
-					Model hibernate = session.get( Model.class, hibernateModel.getId() );
-					hibernate.getProductLine().getId();
-					for ( Model aList : models ) {
-						assertFalse( Hibernate.isInitialized( aList.getProductLine() ) );
-					}
-					assertEquals( hibernate.getProductLine().getDescription(), "OSS" ); //fetch both productlines
-
-					session.clear();
-
-					Iterator iter = session.createQuery( "from Model" ).list().iterator();
-					models = new ArrayList();
-					while ( iter.hasNext() ) {
-						models.add( iter.next() );
-					}
-					Model m = models.get( 0 );
-					m.getDescription(); //fetch a batch of 4
-
-					session.clear();
-
-					list = session.createQuery( "from ProductLine" ).list();
-					ProductLine pl = list.get( 0 );
-					ProductLine pl2 = list.get( 1 );
-					session.evict( pl2 );
-					pl.getModels().size(); //fetch just one collection! (how can we write an assertion for that??)
-				}
-		);
-
-		scope.inTransaction(
-				session -> {
-					List list = session.createQuery( "from ProductLine pl order by pl.description" )
-							.list();
-					ProductLine cars = list.get( 0 );
-					ProductLine oss = list.get( 1 );
-					assertEquals( cars.getModels().size(), 2 );
-					assertEquals( oss.getModels().size(), 3 );
-					session.remove( cars );
-					session.remove( oss );
-				}
-		);
+		ProductLine ossProductLine = new ProductLine( "OSS" );
+		Model hibernateModel = new Model( "Hibernate", "Hibernate", ossProductLine );
+		scope.inTransaction( (session) -> {
+			ProductLine cars = new ProductLine( "Cars" );
+			new Model( "monaro", "Holden Monaro", cars );
+			new Model( "hsv", "Holden Commodore HSV", cars );
+			session.persist( cars );
+
+			ossProductLine.setDescription( "OSS" );
+			new Model( "JBoss", "JBoss Application Server", ossProductLine );
+			new Model( "JBossCache", "JBoss TreeCache", ossProductLine );
+			session.persist( ossProductLine );
+		} );
+
+		scope.inTransaction( (session) -> {
+			List list = session.createQuery( "from ProductLine pl order by pl.description" ).list();
+			ProductLine cars = list.get( 0 );
+			ProductLine oss = list.get( 1 );
+			assertFalse( Hibernate.isInitialized( cars.getModels() ) );
+			assertFalse( Hibernate.isInitialized( oss.getModels() ) );
+			assertEquals( 2, cars.getModels().size() ); //fetch both collections
+			assertTrue( Hibernate.isInitialized( cars.getModels() ) );
+			assertTrue( Hibernate.isInitialized( oss.getModels() ) );
+		} );
+
+		scope.inTransaction( (session) -> {
+			List models = session.createQuery( "from Model m" ).list();
+			Model hibernate = session.find( Model.class, hibernateModel.getId() );
+			hibernate.getProductLine().getId();
+			for ( Model aList : models ) {
+				assertFalse( Hibernate.isInitialized( aList.getProductLine() ) );
+			}
+			//fetch both product lines
+			assertThat( hibernate.getProductLine().getDescription() ).isEqualTo( "OSS" );
+		} );
+
+		scope.inTransaction( (session) -> {
+			Iterator iter = session.createQuery( "from Model" ).list().iterator();
+			ArrayList models = new ArrayList<>();
+			while ( iter.hasNext() ) {
+				models.add( iter.next() );
+			}
+			Model m = models.get( 0 );
+			m.getDescription(); //fetch a batch of 4
+
+			session.clear();
+
+			List list = session.createQuery( "from ProductLine" ).list();
+			ProductLine pl = list.get( 0 );
+			ProductLine pl2 = list.get( 1 );
+			session.evict( pl2 );
+			pl.getModels().size(); //fetch just one collection! (how can we write an assertion for that??)
+		} );
+
+		scope.inTransaction( (session) -> {
+			List list = session.createQuery( "from ProductLine pl order by pl.description" ).list();
+			ProductLine cars = list.get( 0 );
+			ProductLine oss = list.get( 1 );
+			assertThat( cars.getModels().size() ).isEqualTo( 2 );
+			assertThat( oss.getModels().size() ).isEqualTo( 3 );
+		} );
 	}
 
 	@Test
-	@SuppressWarnings("unchecked")
 	public void testBatchFetch2(SessionFactoryScope scope) {
 		int size = 32 + 14;
-		scope.inTransaction(
-				session -> {
-					for ( int i = 0; i < size; i++ ) {
-						session.persist( new BatchLoadableEntity( i ) );
-					}
-				}
-		);
-
-		scope.inTransaction(
-				session -> {
-					// load them all as proxies
-					for ( int i = 0; i < size; i++ ) {
-						BatchLoadableEntity entity = session.getReference( BatchLoadableEntity.class, i );
-						assertFalse( Hibernate.isInitialized( entity ) );
-					}
-					scope.getSessionFactory().getStatistics().clear();
-					// now start initializing them...
-					for ( int i = 0; i < size; i++ ) {
-						BatchLoadableEntity entity = session.getReference( BatchLoadableEntity.class, i );
-						Hibernate.initialize( entity );
-						assertTrue( Hibernate.isInitialized( entity ) );
-					}
-					// so at this point, all entities are initialized.  see how many fetches were performed.
-					final int expectedFetchCount;
-//		if ( sessionFactory().getSettings().getBatchFetchStyle() == BatchFetchStyle.LEGACY ) {
-//			expectedFetchCount = 3; // (32 + 10 + 4)
-//		}
-//		else if ( sessionFactory().getSettings().getBatchFetchStyle() == BatchFetchStyle.DYNAMIC ) {
-//			expectedFetchCount = 2;  // (32 + 14) : because we limited batch-size to 32
-//		}
-//		else {
-					// PADDED
-					expectedFetchCount = 2; // (32 + 16*) with the 16 being padded
-//		}
-					assertEquals(
-							expectedFetchCount,
-							scope.getSessionFactory().getStatistics()
-									.getEntityStatistics( BatchLoadableEntity.class.getName() )
-									.getFetchCount()
-					);
-				}
-		);
+		scope.inTransaction( (session) -> {
+			for ( int i = 0; i < size; i++ ) {
+				session.persist( new BatchLoadableEntity( i ) );
+			}
+		} );
+
+		scope.inTransaction( (session) -> {
+			// load them all as proxies
+			for ( int i = 0; i < size; i++ ) {
+				BatchLoadableEntity entity = session.getReference( BatchLoadableEntity.class, i );
+				assertFalse( Hibernate.isInitialized( entity ) );
+			}
+			scope.getSessionFactory().getStatistics().clear();
+			// now start initializing them...
+			for ( int i = 0; i < size; i++ ) {
+				BatchLoadableEntity entity = session.getReference( BatchLoadableEntity.class, i );
+				Hibernate.initialize( entity );
+				assertTrue( Hibernate.isInitialized( entity ) );
+			}
+			// so at this point, all entities are initialized.  see how many fetches were performed.
+			final int expectedFetchCount;
+			expectedFetchCount = 2; // (32 + 16*) with the 16 being padded
+
+			assertEquals(
+					expectedFetchCount,
+					scope.getSessionFactory().getStatistics()
+							.getEntityStatistics( BatchLoadableEntity.class.getName() )
+							.getFetchCount()
+			);
+		} );
 	}
 
 	@AfterEach
 	public void tearDown(SessionFactoryScope scope) {
-		scope.inTransaction(
-				session -> {
-					session.createQuery( "delete BatchLoadableEntity" ).executeUpdate();
-					session.createQuery( "delete Model" ).executeUpdate();
-					session.createQuery( "delete ProductLine" ).executeUpdate();
-				}
-		);
+		scope.dropData();
 	}
 }
diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/Model.java b/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/Model.java
index 72aecb38375c..3df626cf2127 100644
--- a/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/Model.java
+++ b/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/Model.java
@@ -9,16 +9,19 @@
  * @author Gavin King
  */
 public class Model {
-	private String id;
+	private Integer id;
 	private String name;
 	private String description;
 	private ProductLine productLine;
 
-	Model() {}
+	Model() {
+	}
 
-	public Model(ProductLine pl) {
-		this.productLine = pl;
-		pl.getModels().add(this);
+	public Model(String name, String description, ProductLine productLine) {
+		this.name = name;
+		this.description = description;
+		this.productLine = productLine;
+		productLine.getModels().add(this);
 	}
 
 	public String getDescription() {
@@ -27,10 +30,10 @@ public String getDescription() {
 	public void setDescription(String description) {
 		this.description = description;
 	}
-	public String getId() {
+	public Integer getId() {
 		return id;
 	}
-	public void setId(String id) {
+	public void setId(Integer id) {
 		this.id = id;
 	}
 	public String getName() {
diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/ProductLine.java b/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/ProductLine.java
index 81bc70b78681..fd44ea86fabb 100644
--- a/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/ProductLine.java
+++ b/hibernate-core/src/test/java/org/hibernate/orm/test/batchfetch/ProductLine.java
@@ -10,21 +10,27 @@
  * @author Gavin King
  */
 public class ProductLine {
-
-	private String id;
+	private Integer id;
 	private String description;
 	private Set models = new HashSet();
 
+	public ProductLine() {
+	}
+
+	public ProductLine(String description) {
+		this.description = description;
+	}
+
 	public String getDescription() {
 		return description;
 	}
 	public void setDescription(String description) {
 		this.description = description;
 	}
-	public String getId() {
+	public Integer getId() {
 		return id;
 	}
-	public void setId(String id) {
+	public void setId(Integer id) {
 		this.id = id;
 	}
 	public Set getModels() {
diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/hql/ASTParserLoadingTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/ASTParserLoadingTest.java
index 2ad603699601..1bf0e275db38 100644
--- a/hibernate-core/src/test/java/org/hibernate/orm/test/hql/ASTParserLoadingTest.java
+++ b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/ASTParserLoadingTest.java
@@ -105,7 +105,7 @@
 				"/org/hibernate/orm/test/hql/ComponentContainer.hbm.xml",
 				"/org/hibernate/orm/test/hql/VariousKeywordPropertyEntity.hbm.xml",
 				"/org/hibernate/orm/test/hql/Constructor.hbm.xml",
-				"/org/hibernate/orm/test/batchfetch/ProductLine.hbm.xml",
+				"/org/hibernate/orm/test/batchfetch/ProductLine.xml",
 				"/org/hibernate/orm/test/cid/Customer.hbm.xml",
 				"/org/hibernate/orm/test/cid/Order.hbm.xml",
 				"/org/hibernate/orm/test/cid/LineItem.hbm.xml",
diff --git a/hibernate-core/src/test/resources/org/hibernate/orm/test/batchfetch/ProductLine.hbm.xml b/hibernate-core/src/test/resources/org/hibernate/orm/test/batchfetch/ProductLine.hbm.xml
deleted file mode 100644
index 2653901a0264..000000000000
--- a/hibernate-core/src/test/resources/org/hibernate/orm/test/batchfetch/ProductLine.hbm.xml
+++ /dev/null
@@ -1,65 +0,0 @@
-
-
-
-
-
-
-
-
-    
-    
-    	
-    		
-    	
-    	
-    	
-    	
-    	
-    		
-    		
-    	
-    	
-	
-
-    
-    
-    	
-    		
-    	
-    	
-    	
-    		
-    	
-    	
-    	
-    	
-	
-
-
diff --git a/hibernate-core/src/test/resources/org/hibernate/orm/test/batchfetch/ProductLine.xml b/hibernate-core/src/test/resources/org/hibernate/orm/test/batchfetch/ProductLine.xml
new file mode 100644
index 000000000000..0dfe2ff1ceaa
--- /dev/null
+++ b/hibernate-core/src/test/resources/org/hibernate/orm/test/batchfetch/ProductLine.xml
@@ -0,0 +1,53 @@
+
+
+
+
+	
+	org.hibernate.orm.test.batchfetch
+
+	
+		64
+		
+			
+				
+			
+			
+				
+			
+			
+				64
+				
+					
+				
+			
+		
+	
+
+	
+		64
+		
+			
+				
+			
+			
+				
+			
+			
+				
+			
+			
+				
+			
+		
+	
+