Skip to content

Commit f95beb9

Browse files
committed
[INTERNAL] Speed up dgan tests
GitOrigin-RevId: 962311af186cafc55b6020c97d6bc90c29c77703
1 parent ae129c5 commit f95beb9

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

tests/timeseries_dgan/test_dgan.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def config() -> DGANConfig:
6767
max_sequence_len=20,
6868
sample_len=5,
6969
batch_size=10,
70-
epochs=10,
70+
epochs=1,
7171
)
7272

7373

@@ -162,10 +162,10 @@ def test_generate():
162162
attributes, features, attributes_shape=(64, 3), features_shape=(64, 20, 2)
163163
)
164164

165-
attributes, features = dg.generate_numpy(200)
165+
attributes, features = dg.generate_numpy(50)
166166

167167
assert_attributes_features_shape(
168-
attributes, features, attributes_shape=(200, 3), features_shape=(200, 20, 2)
168+
attributes, features, attributes_shape=(50, 3), features_shape=(50, 20, 2)
169169
)
170170

171171
attributes, features = dg.generate_numpy(1)
@@ -297,7 +297,7 @@ def test_train_numpy_no_attributes_1(
297297
def test_train_numpy_no_attributes_2(config: DGANConfig):
298298
features = np.random.rand(100, 20, 2)
299299
n_samples = 10
300-
config.epochs = 1
300+
301301
model_attributes_blank = DGAN(config=config)
302302
model_attributes_blank.train_numpy(features=features)
303303
synthetic_attributes, synthetic_features = model_attributes_blank.generate_numpy(
@@ -331,7 +331,6 @@ def test_train_numpy_batch_size_of_1(config: DGANConfig):
331331
# Check model trains when (# of examples) % batch_size == 1.
332332

333333
config.batch_size = 10
334-
config.epochs = 1
335334

336335
features = np.random.rand(91, 20, 2)
337336
attributes = np.random.randint(0, 3, (91, 1))
@@ -506,7 +505,6 @@ def test_train_dataframe_wide_no_attributes(config: DGANConfig):
506505

507506
config.max_sequence_len = 4
508507
config.sample_len = 1
509-
config.epochs = 1
510508

511509
dg = DGAN(config=config)
512510
dg.train_dataframe(df=df, df_style=DfStyle.WIDE)
@@ -1901,7 +1899,6 @@ def test_save_and_load(
19011899
attributes, attribute_types = attribute_data
19021900
features, feature_types = feature_data
19031901

1904-
config.epochs = 1
19051902
config.use_attribute_discriminator = use_attribute_discriminator
19061903
config.apply_example_scaling = apply_example_scaling
19071904
config.attribute_noise_dim = noise_dim
@@ -1957,7 +1954,6 @@ def test_save_and_load_no_attributes(
19571954
):
19581955
features, feature_types = feature_data
19591956

1960-
config.epochs = 1
19611957
config.use_attribute_discriminator = use_attribute_discriminator
19621958
config.apply_example_scaling = apply_example_scaling
19631959
config.attribute_noise_dim = noise_dim
@@ -2009,7 +2005,6 @@ def test_save_and_load_dataframe_with_attributes(config: DGANConfig, tmp_path):
20092005
)
20102006
config.max_sequence_len = 4
20112007
config.sample_len = 1
2012-
config.epochs = 1
20132008

20142009
dg = DGAN(config=config)
20152010

@@ -2044,7 +2039,6 @@ def test_attribute_and_feature_overlap(config: DGANConfig):
20442039
)
20452040
config.max_sequence_len = 4
20462041
config.sample_len = 1
2047-
config.epochs = 1
20482042

20492043
dg = DGAN(config=config)
20502044

@@ -2070,7 +2064,6 @@ def test_save_and_load_dataframe_no_attributes(config: DGANConfig, tmp_path):
20702064

20712065
config.max_sequence_len = 3
20722066
config.sample_len = 1
2073-
config.epochs = 1
20742067

20752068
dg = DGAN(config=config)
20762069

@@ -2101,7 +2094,6 @@ def test_dataframe_long_no_continuous_features(config: DGANConfig):
21012094

21022095
config.max_sequence_len = 3
21032096
config.sample_len = 1
2104-
config.epochs = 1
21052097

21062098
dg = DGAN(config=config)
21072099

@@ -2124,7 +2116,6 @@ def test_dataframe_wide_no_continuous_features(config: DGANConfig):
21242116

21252117
config.max_sequence_len = 3
21262118
config.sample_len = 1
2127-
config.epochs = 1
21282119

21292120
dg = DGAN(config=config)
21302121

@@ -2146,7 +2137,6 @@ def test_dataframe_long_partial_example(config: DGANConfig):
21462137

21472138
config.max_sequence_len = 10
21482139
config.sample_len = 1
2149-
config.epochs = 1
21502140

21512141
dg = DGAN(config=config)
21522142

@@ -2170,7 +2160,6 @@ def test_dataframe_long_one_and_partial_example(config: DGANConfig):
21702160

21712161
config.max_sequence_len = 5
21722162
config.sample_len = 1
2173-
config.epochs = 1
21742163

21752164
dg = DGAN(config=config)
21762165

@@ -2203,7 +2192,6 @@ def test_dataframe_variable_sequences(config: DGANConfig):
22032192

22042193
config.max_sequence_len = 8
22052194
config.sample_len = 1
2206-
config.epochs = 1
22072195

22082196
dg = DGAN(config=config)
22092197

0 commit comments

Comments
 (0)