@@ -37,6 +37,8 @@ def get_existing_objects(self):
3737 "has_password" : r ["has_password" ] == "true" ,
3838 "has_rsa_public_key" : r ["has_rsa_public_key" ] == "true" ,
3939 "has_mfa" : r ["has_mfa" ] == "true" ,
40+ "has_pat" : r ["has_pat" ] == "true" ,
41+ "has_workload_identity" : r ["has_workload_identity" ] == "true" ,
4042 "comment" : r ["comment" ] if r ["comment" ] else None ,
4143 }
4244
@@ -96,6 +98,12 @@ def create_object(self, bp: UserBlueprint):
9698 if bp .type :
9799 query .append_nl ("TYPE = {type}" , {"type" : bp .type })
98100
101+ # Workload identity
102+ if bp .workload_identity :
103+ query .append_nl ("WORKLOAD_IDENTITY = (" )
104+ query .append (self ._build_workload_identity_parameters (bp ))
105+ query .append_nl (")" )
106+
99107 # Object and session parameters
100108 query .append (self ._build_common_parameters (bp ))
101109
@@ -123,9 +131,15 @@ def compare_object(self, bp: UserBlueprint, row: dict):
123131 if self ._compare_public_keys (bp , row ):
124132 result = ResolveResult .ALTER
125133
134+ if self ._compare_workload_identity_pre_type (bp , row ):
135+ result = ResolveResult .ALTER
136+
126137 if self ._compare_type (bp , row ):
127138 result = ResolveResult .ALTER
128139
140+ if self ._compare_workload_identity_post_type (bp , row ):
141+ result = ResolveResult .ALTER
142+
129143 if self ._compare_parameters (bp ):
130144 result = ResolveResult .ALTER
131145
@@ -158,6 +172,23 @@ def _build_common_parameters(self, bp: UserBlueprint):
158172
159173 return query
160174
175+ def _build_workload_identity_parameters (self , bp : UserBlueprint ):
176+ query = self .engine .query_builder ()
177+
178+ for param_name , param_value in bp .workload_identity .items ():
179+ query .append_nl (
180+ " {param_name:r} = {param_value:dp}" ,
181+ {
182+ "param_name" : param_name ,
183+ # ISSUER + SUBJECT is the unique key in Snowflake
184+ # SnowDDL has to append env_prefix in order to prevent duplicate key error
185+ # Feel free to adjust this logic for your own custom test environment
186+ "param_value" : f"{ param_value } :{ self .config .env_prefix .rstrip ('_$' )} " if param_name == "SUBJECT" else param_value ,
187+ },
188+ )
189+
190+ return query
191+
161192 def _compare_properties (self , bp : UserBlueprint , row : dict ):
162193 query = self .engine .query_builder ()
163194
@@ -336,6 +367,39 @@ def _compare_parameters(self, bp: UserBlueprint):
336367
337368 return False
338369
370+ def _compare_workload_identity_pre_type (self , bp : UserBlueprint , row : dict ):
371+ if not bp .workload_identity and row ["has_workload_identity" ]:
372+ self .engine .execute_safe_ddl (
373+ "ALTER USER {name:i} UNSET WORKLOAD_IDENTITY" ,
374+ {
375+ "name" : bp .full_name ,
376+ },
377+ )
378+
379+ return True
380+
381+ return False
382+
383+ def _compare_workload_identity_post_type (self , bp : UserBlueprint , row : dict ):
384+ if bp .workload_identity and (not row ["has_workload_identity" ] or self .engine .settings .refresh_workload_identity ):
385+ query = self .engine .query_builder ()
386+
387+ query .append (
388+ "ALTER USER {name:i} SET WORKLOAD_IDENTITY = (" ,
389+ {
390+ "name" : bp .full_name ,
391+ }
392+ )
393+
394+ query .append (self ._build_workload_identity_parameters (bp ))
395+ query .append_nl (")" )
396+
397+ self .engine .execute_safe_ddl (query )
398+
399+ return True
400+
401+ return False
402+
339403 def _check_user_role_grant (self , bp : UserBlueprint ):
340404 user_role = self ._get_user_role_ident (bp )
341405
0 commit comments