@@ -180,6 +180,82 @@ def deploy_component(profile_name, component_path, component_name, component_typ
180
180
code_dir = os .path .join (component_path , subdirs [0 ])
181
181
logger .info (f"Using actual code directory: { code_dir } " )
182
182
183
+ # Check for UDF function signature to automatically fix issues
184
+ function_file = os .path .join (code_dir , "function.py" )
185
+ has_session_param = False
186
+
187
+ if os .path .exists (function_file ):
188
+ try :
189
+ with open (function_file , 'r' ) as f :
190
+ content = f .read ()
191
+
192
+ # Check if we need to fix the function signature for UDFs with session parameters
193
+ if component_type .lower () == "udf" and "def main(session, input_data" in content :
194
+ logger .info ("Detected UDF with session parameter - applying automatic fix" )
195
+ has_session_param = True
196
+
197
+ # Check if we need to automatically modify UDF with wrapper function
198
+ if "def main_with_session(" not in content :
199
+ # Create a backup of the original file
200
+ backup_file = function_file + ".bak"
201
+ with open (backup_file , 'w' ) as f :
202
+ f .write (content )
203
+
204
+ # Modify the content to include a wrapper
205
+ modified_content = content .replace (
206
+ "def main(session, input_data" ,
207
+ """# Original function
208
+ def main_with_session(session, input_data
209
+
210
+ # Wrapper function that Snowflake calls directly
211
+ def main(input_data"""
212
+ )
213
+
214
+ # Add the wrapper implementation at the end
215
+ indent = " " # Default indentation
216
+ lines = modified_content .split ("\n " )
217
+ # Find where to insert wrapper code
218
+ insert_pos = len (lines ) # Default to end of file
219
+ for i in range (len (lines )):
220
+ if lines [i ].strip () == "def main(input_data" :
221
+ # Find indentation level and next non-empty line
222
+ for j in range (i + 1 , len (lines )):
223
+ if lines [j ].strip ():
224
+ indent = lines [j ].split (lines [j ].lstrip ())[0 ]
225
+ break
226
+
227
+ # Add wrapper implementation
228
+ wrapper_code = [
229
+ f"{ indent } # Get session from Snowflake context" ,
230
+ f"{ indent } from snowflake.snowpark.context import get_active_session" ,
231
+ f"{ indent } session = get_active_session()" ,
232
+ f"{ indent } # Call the original function with session" ,
233
+ f"{ indent } return main_with_session(session, input_data)"
234
+ ]
235
+
236
+ # Find the position to insert the wrapper code
237
+ for i in range (len (lines )- 1 , - 1 , - 1 ):
238
+ if "def main(input_data" in lines [i ]:
239
+ # Find where the function body ends
240
+ func_indent = lines [i + 1 ].split (lines [i + 1 ].lstrip ())[0 ]
241
+ for j in range (i + 1 , len (lines )):
242
+ if j == len (lines )- 1 or (lines [j ] and not lines [j ].startswith (func_indent )):
243
+ insert_pos = j
244
+ break
245
+
246
+ # Insert wrapper code at appropriate position
247
+ for code_line in wrapper_code :
248
+ lines .insert (insert_pos , code_line )
249
+ insert_pos += 1
250
+
251
+ # Write modified content back
252
+ with open (function_file , 'w' ) as f :
253
+ f .write ("\n " .join (lines ))
254
+
255
+ logger .info (f"Modified UDF function to handle session parameter" )
256
+ except Exception as e :
257
+ logger .warning (f"Could not check/fix function signature: { str (e )} " )
258
+
183
259
# Log directory contents
184
260
logger .info (f"Component directory structure:" )
185
261
for root , dirs , files in os .walk (component_path ):
@@ -217,16 +293,29 @@ def deploy_component(profile_name, component_path, component_name, component_typ
217
293
import_path = f"@{ stage_name } /{ component_name .replace (' ' , '_' )} /{ zip_filename } "
218
294
219
295
if component_type .lower () == "udf" :
220
- # For Snowpark UDFs that use session parameter
221
- sql = f"""
222
- CREATE OR REPLACE FUNCTION { component_name .replace (' ' , '_' )} (input_data VARIANT)
223
- RETURNS VARIANT
224
- LANGUAGE PYTHON
225
- RUNTIME_VERSION=3.8
226
- PACKAGES = ('snowflake-snowpark-python')
227
- IMPORTS = ('{ import_path } ')
228
- HANDLER = 'function.main'
229
- """
296
+ # For Snowpark UDFs - different SQL based on parameter signature
297
+ if has_session_param :
298
+ # For UDFs with session parameter (Snowpark style)
299
+ sql = f"""
300
+ CREATE OR REPLACE FUNCTION { component_name .replace (' ' , '_' )} (input_data VARIANT)
301
+ RETURNS VARIANT
302
+ LANGUAGE PYTHON
303
+ RUNTIME_VERSION=3.8
304
+ PACKAGES = ('snowflake-snowpark-python')
305
+ HANDLER = 'function.main'
306
+ IMPORTS = ('{ import_path } ')
307
+ """
308
+ else :
309
+ # For basic UDFs without session parameter
310
+ sql = f"""
311
+ CREATE OR REPLACE FUNCTION { component_name .replace (' ' , '_' )} (input_data VARIANT)
312
+ RETURNS VARIANT
313
+ LANGUAGE PYTHON
314
+ RUNTIME_VERSION=3.8
315
+ PACKAGES = ('snowflake-snowpark-python')
316
+ IMPORTS = ('{ import_path } ')
317
+ HANDLER = 'function.main'
318
+ """
230
319
else : # procedure
231
320
sql = f"""
232
321
CREATE OR REPLACE PROCEDURE { component_name .replace (' ' , '_' )} ()
0 commit comments