Skip to content

Commit 1748f21

Browse files
committed
dev changes deployment for the udfs
1 parent c8a9266 commit 1748f21

File tree

2 files changed

+130
-22
lines changed

2 files changed

+130
-22
lines changed

scripts/deployment_files/check_and_fix_udf.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import os
22
import sys
33
import argparse
4+
import logging
5+
6+
# Setup logging
7+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
8+
logger = logging.getLogger(__name__)
49

510
def fix_udf_function(udf_path):
611
"""
@@ -10,7 +15,7 @@ def fix_udf_function(udf_path):
1015
function_file = os.path.join(udf_path, "function.py")
1116

1217
if not os.path.exists(function_file):
13-
print(f"Error: Function file not found at: {function_file}")
18+
logger.error(f"Error: Function file not found at: {function_file}")
1419
return False
1520

1621
# Back up original file
@@ -24,9 +29,15 @@ def fix_udf_function(udf_path):
2429
f.write(content)
2530

2631
# Check if the function has a session parameter
27-
if "def main(session, input_data" in content:
28-
print("Found Snowpark UDF with session parameter")
32+
# We need to be more thorough in our detection
33+
if "def main(session" in content or "def main( session" in content:
34+
logger.info("Found Snowpark UDF with session parameter")
2935

36+
# Check if it's already been wrapped
37+
if "def main_with_session(session" in content:
38+
logger.info("Function is already wrapped - no need to modify")
39+
return True
40+
3041
# Add wrapper function that gets session from snowflake.snowpark.functions
3142
modified_content = content.replace(
3243
"def main(session, input_data",
@@ -69,27 +80,61 @@ def main(input_data"""
6980
with open(function_file, 'w') as f:
7081
f.write(modified_content)
7182

72-
print(f"✅ Updated {function_file} with wrapper function")
73-
print(f"Original file backed up to {backup_file}")
83+
logger.info(f"✅ Updated {function_file} with wrapper function")
84+
logger.info(f"Original file backed up to {backup_file}")
7485
return True
7586
else:
76-
print("No session parameter detected, no changes needed")
87+
logger.info("No session parameter detected, no changes needed")
7788
return False
7889

7990
except Exception as e:
80-
print(f"Error fixing UDF: {e}")
91+
logger.error(f"Error fixing UDF: {e}")
8192
# Try to restore backup if it exists
8293
if os.path.exists(backup_file):
83-
print("Restoring backup...")
94+
logger.info("Restoring backup...")
8495
with open(backup_file, 'r') as f:
8596
original = f.read()
8697
with open(function_file, 'w') as f:
8798
f.write(original)
8899
return False
89100

101+
def analyze_udf_file(udf_path):
    """Report whether the UDF's ``main`` handler declares a ``session`` parameter.

    The file ``function.py`` inside ``udf_path`` is read as text (it is never
    imported or executed) and the first ``def main(...)`` signature is
    inspected.

    Args:
        udf_path: Directory expected to contain ``function.py``.

    Returns:
        True when ``main`` has a parameter named exactly ``session``;
        False when it does not, when no ``main`` definition is found, or when
        ``function.py`` is missing (an error is logged in that case).
    """
    import re

    function_file = os.path.join(udf_path, "function.py")

    if not os.path.exists(function_file):
        logger.error(f"Function file not found: {function_file}")
        # Always answer with a bool so callers can rely on the return type.
        return False

    with open(function_file, 'r') as f:
        content = f.read()

    logger.info(f"Analyzing UDF in {function_file}")

    # DOTALL lets the parameter list span several lines.
    sig_match = re.search(r'def\s+main\s*\((.*?)\)', content, re.DOTALL)
    if sig_match:
        params = sig_match.group(1)
        logger.info(f"Function signature parameters: '{params}'")

        # Compare bare parameter names (annotation/default stripped) so that
        # names such as ``session_id`` are not mistaken for ``session``.
        names = [
            re.split(r'[:=]', piece, maxsplit=1)[0].strip()
            for piece in params.split(',')
        ]
        if "session" in names:
            logger.info("This UDF uses the Snowpark session parameter and needs wrapping")
            return True

    return False
125+
90126
if __name__ == "__main__":
    # Command-line entry point: takes the UDF directory and an optional
    # --analyze flag that limits the run to reporting the signature.
    cli = argparse.ArgumentParser(description='Fix UDF function for Snowflake compatibility')
    cli.add_argument('udf_path', help='Path to UDF directory containing function.py')
    cli.add_argument('--analyze', action='store_true', help="Just analyze, don't fix")
    options = cli.parse_args()

    # The analysis runs in both modes; only the default mode acts on its result.
    uses_session = analyze_udf_file(options.udf_path)
    if not options.analyze:
        if uses_session:
            fix_udf_function(options.udf_path)
        else:
            logger.info("UDF doesn't need fixing")

scripts/deployment_files/snowflake_deployer.py

Lines changed: 76 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,41 @@ def create_snowflake_connection(conn_config):
147147
logger.error(f"Failed to create Snowflake connection: {str(e)}")
148148
raise
149149

150+
def analyze_function_signature(function_file):
    """Summarise the parameters of ``main`` in a UDF source file.

    The file is scanned as text (it is never imported or executed) for the
    first ``def main(...)`` definition. Multi-line signatures and parameters
    whose annotations or defaults contain commas or brackets are supported.

    Args:
        function_file: Path to the Python source file holding the handler.

    Returns:
        A dict with three keys:
            ``has_session``: True when a parameter is named exactly ``session``.
            ``param_count``: number of declared parameters, NOT counting
                ``session``.
            ``param_list``: every declared parameter as written in the source
                (annotation/default kept, ``session`` included), stripped.
        When the file cannot be read or no complete ``main`` signature is
        found, the fallback ``{'has_session': False, 'param_count': 1,
        'param_list': ['input_data']}`` is returned.
    """
    import re

    try:
        with open(function_file, 'r') as f:
            content = f.read()
    except (OSError, UnicodeError) as e:
        # Best effort: a missing or unreadable file falls through to the default.
        logger.error(f"Error analyzing function signature: {e}")
        content = ""

    # Look for the main function definition
    header = re.search(r'def\s+main\s*\(', content)
    raw_params = _split_signature_params(content, header.end()) if header else None

    if raw_params is not None:
        param_list = [piece.strip() for piece in raw_params if piece.strip()]
        params = ', '.join(param_list)
        logger.info(f"Function signature parameters: '{params}'")

        # Reduce each parameter to its bare name so that only an exact
        # ``session`` matches (``session_id`` must not).
        names = [
            re.split(r'[:=]', piece, maxsplit=1)[0].strip().lstrip('*')
            for piece in param_list
        ]
        has_session = 'session' in names

        # Count parameters (excluding session if present)
        param_count = sum(1 for name in names if name != 'session')

        return {
            'has_session': has_session,
            'param_count': param_count,
            'param_list': param_list
        }

    # Default fallback
    return {
        'has_session': False,
        'param_count': 1,
        'param_list': ['input_data']
    }


def _split_signature_params(source, start):
    """Split a parameter list into its top-level comma-separated pieces.

    ``start`` is the index just after the opening parenthesis. Commas nested
    inside ``()``, ``[]`` or ``{}`` do not split. String literals are not
    parsed, so brackets or commas inside a string default may mislead it.

    Returns:
        The list of raw pieces, or None if the closing parenthesis is missing.
    """
    pieces = []
    current = []
    depth = 0
    for ch in source[start:]:
        if ch in '([{':
            depth += 1
        elif ch in ')]}':
            if depth == 0:
                # This is the parenthesis that closes the signature.
                pieces.append(''.join(current))
                return pieces
            depth -= 1
        elif ch == ',' and depth == 0:
            pieces.append(''.join(current))
            current = []
            continue
        current.append(ch)
    return None
184+
150185
def deploy_component(profile_name, component_path, component_name, component_type):
151186
"""Deploy a component to Snowflake."""
152187
logger.info(f"Deploying component: {component_name} ({component_type})")
@@ -196,9 +231,14 @@ def deploy_component(profile_name, component_path, component_name, component_typ
196231
logger.info(f"Using actual code directory: {code_dir}")
197232

198233
# Check and fix UDF function signature if necessary
199-
if component_type.lower() == "udf":
200-
logger.info(f"Checking and fixing UDF function signature for {component_name}")
201-
os.system(f"python scripts/deployment_files/check_and_fix_udf.py {code_dir}")
234+
function_file = os.path.join(code_dir, "function.py")
235+
signature_info = {'has_session': False, 'param_count': 1, 'param_list': ['input_data']}
236+
237+
if component_type.lower() == "udf" and os.path.exists(function_file):
238+
logger.info(f"Analyzing UDF function signature for {component_name}")
239+
signature_info = analyze_function_signature(function_file)
240+
logger.info(f"Function analysis: Session={signature_info['has_session']}, "
241+
f"Param Count={signature_info['param_count']}")
202242

203243
# Log directory contents
204244
logger.info(f"Component directory structure:")
@@ -224,16 +264,39 @@ def deploy_component(profile_name, component_path, component_name, component_typ
224264
import_path = f"@{stage_name}/{component_name.replace(' ', '_')}/{zip_filename}"
225265

226266
if component_type.lower() == "udf":
227-
# For Snowpark UDFs - different SQL based on parameter signature
228-
sql = f"""
229-
CREATE OR REPLACE FUNCTION {component_name.replace(' ', '_')}(input_data VARIANT)
230-
RETURNS VARIANT
231-
LANGUAGE PYTHON
232-
RUNTIME_VERSION=3.8
233-
PACKAGES = ('snowflake-snowpark-python')
234-
HANDLER = 'function.main'
235-
IMPORTS = ('{import_path}')
236-
"""
267+
# Adjust SQL based on parameter count
268+
if signature_info['param_count'] == 1:
269+
sql = f"""
270+
CREATE OR REPLACE FUNCTION {component_name.replace(' ', '_')}(input_data VARIANT)
271+
RETURNS VARIANT
272+
LANGUAGE PYTHON
273+
RUNTIME_VERSION=3.8
274+
PACKAGES = ('snowflake-snowpark-python')
275+
IMPORTS = ('{import_path}')
276+
HANDLER = 'function.main'
277+
"""
278+
elif signature_info['param_count'] == 2:
279+
# For two parameters
280+
sql = f"""
281+
CREATE OR REPLACE FUNCTION {component_name.replace(' ', '_')}(current_value FLOAT, previous_value FLOAT)
282+
RETURNS FLOAT
283+
LANGUAGE PYTHON
284+
RUNTIME_VERSION=3.8
285+
PACKAGES = ('snowflake-snowpark-python')
286+
IMPORTS = ('{import_path}')
287+
HANDLER = 'function.main'
288+
"""
289+
else:
290+
# Fallback to standard variant parameter
291+
sql = f"""
292+
CREATE OR REPLACE FUNCTION {component_name.replace(' ', '_')}(input_data VARIANT)
293+
RETURNS VARIANT
294+
LANGUAGE PYTHON
295+
RUNTIME_VERSION=3.8
296+
PACKAGES = ('snowflake-snowpark-python')
297+
IMPORTS = ('{import_path}')
298+
HANDLER = 'function.main'
299+
"""
237300
else: # procedure
238301
sql = f"""
239302
CREATE OR REPLACE PROCEDURE {component_name.replace(' ', '_')}()

0 commit comments

Comments
 (0)