|
13 | 13 | import pickle
|
14 | 14 | import zipfile, tarfile
|
15 | 15 | import sys
|
| 16 | +import socket |
| 17 | +import stat |
| 18 | +import hashlib |
16 | 19 | from unittest import mock, SkipTest, skipIf, skipUnless, expectedFailure
|
17 | 20 | from contextlib import contextmanager
|
18 | 21 | from glob import glob
|
@@ -5373,3 +5376,108 @@ def test_fortran_new_module_in_dep(self) -> None:
|
5373 | 5376 | output = entry['output']
|
5374 | 5377 |
|
5375 | 5378 | self.build(output, extra_args=['-j1'])
|
| 5379 | + |
| 5380 | + @skipIfNoExecutable('sshd') |
| 5381 | + @skipIfNoExecutable('sftp') |
| 5382 | + # Not tested on Windows, since there is not yet an OpenSSH server available |
| 5383 | + def test_wrap_file_sftp(self) -> None: |
| 5384 | + testdir = os.path.join(self.unit_test_dir, '130 wrap file sftp') |
| 5385 | + |
| 5386 | + def write_file(path, contents): |
| 5387 | + with open(path, 'w', encoding='utf-8') as f: |
| 5388 | + f.write(contents) |
| 5389 | + |
| 5390 | + def read_file(path): |
| 5391 | + with open(path, 'r', encoding='utf-8') as f: |
| 5392 | + return f.read() |
| 5393 | + |
| 5394 | + def generate_key(path): |
| 5395 | + subprocess.run(['ssh-keygen', '-f', path, '-N', ''], check=True) |
| 5396 | + os.chmod(path, stat.S_IREAD | stat.S_IWRITE) |
| 5397 | + with open(path + '.pub', 'r', encoding='utf-8') as f: |
| 5398 | + return f.read() |
| 5399 | + |
| 5400 | + def find_free_port(): |
| 5401 | + with socket.socket() as sock: |
| 5402 | + sock.bind(('', 0)) |
| 5403 | + return sock.getsockname()[1] |
| 5404 | + |
| 5405 | + def generate_wrap_file(tmpdir, ssh_server_port, user_key_path, host_public_key, source_hash): |
| 5406 | + os.mkdir(os.path.join(tmpdir, 'subprojects')) |
| 5407 | + write_file(os.path.join(tmpdir, 'subprojects', 'foo.wrap'), |
| 5408 | + textwrap.dedent(f'''\ |
| 5409 | + [wrap-file] |
| 5410 | + directory = foo |
| 5411 | +
|
| 5412 | + source_url = sftp://localhost:{ssh_server_port}/foo.tar.gz |
| 5413 | + source_filename = foo.tar.gz |
| 5414 | + source_hash = {source_hash} |
| 5415 | + source_hostkey = {host_public_key} |
| 5416 | + source_identityfile = {user_key_path} |
| 5417 | + ''')) |
| 5418 | + |
| 5419 | + def generate_sshd_config(sshdir, user_public_key, ssh_server_port, sftpdir): |
| 5420 | + authorized_keys_path = os.path.join(sshdir, 'authorized_keys') |
| 5421 | + write_file(authorized_keys_path, user_public_key) |
| 5422 | + sshd_config_path = os.path.join(sshdir, 'sshd_config') |
| 5423 | + write_file(sshd_config_path, |
| 5424 | + textwrap.dedent(f'''\ |
| 5425 | + ListenAddress localhost:{ssh_server_port} |
| 5426 | + PidFile "{os.path.join(sshdir, 'sshd_pid')}" |
| 5427 | + AuthorizedKeysFile "{authorized_keys_path}" |
| 5428 | + Subsystem sftp internal-sftp -d {sftpdir} |
| 5429 | + PasswordAuthentication no |
| 5430 | + ''')) |
| 5431 | + return sshd_config_path |
| 5432 | + |
| 5433 | + def start_sshd(config_path, host_key_path): |
| 5434 | + sshd_path = shutil.which('sshd') |
| 5435 | + sshd = subprocess.Popen([sshd_path, '-f', config_path, '-h', host_key_path, '-D']) |
| 5436 | + try: |
| 5437 | + sshd.wait(1) |
| 5438 | + return None |
| 5439 | + except subprocess.TimeoutExpired: |
| 5440 | + # It seems sshd started successfully |
| 5441 | + return sshd |
| 5442 | + |
| 5443 | + def hash_file(path): |
| 5444 | + h = hashlib.sha256() |
| 5445 | + with open(path, 'rb') as f: |
| 5446 | + h.update(f.read()) |
| 5447 | + return h.hexdigest() |
| 5448 | + |
| 5449 | + # sshd doesn't like the permissions of /tmp, so keep ssh related files |
| 5450 | + # in a temporary directory inside testdir. |
| 5451 | + # Cannot serve sftp from a path containing spaces, so use a tmp dir for that. |
| 5452 | + with (tempfile.TemporaryDirectory() as projdir, |
| 5453 | + tempfile.TemporaryDirectory(dir=testdir) as sshdir, |
| 5454 | + tempfile.TemporaryDirectory() as sftpdir): |
| 5455 | + shutil.copytree(os.path.join(testdir, 'top'), projdir, dirs_exist_ok=True) |
| 5456 | + shutil.copy(os.path.join(testdir, 'foo.tar.gz'), sftpdir) |
| 5457 | + source_hash = hash_file(os.path.join(testdir, 'foo.tar.gz')) |
| 5458 | + host_key_path = os.path.join(sshdir, 'host_key') |
| 5459 | + user_key_path = os.path.join(sshdir, 'user_key') |
| 5460 | + host_public_key = generate_key(host_key_path) |
| 5461 | + user_public_key = generate_key(user_key_path) |
| 5462 | + |
| 5463 | + # As there is no reliable way to avoid the port being taken between |
| 5464 | + # us finding a free port and starting the server, support a number |
| 5465 | + # of retries. |
| 5466 | + attempts = 0 |
| 5467 | + while attempts < 3: |
| 5468 | + port = find_free_port() |
| 5469 | + sshd_config_path = generate_sshd_config(sshdir, user_public_key, port, sftpdir) |
| 5470 | + sshd = start_sshd(sshd_config_path, host_key_path) |
| 5471 | + if sshd is None: |
| 5472 | + print(f'Failed to start sshd, probably due to port being taken. Trying again.') |
| 5473 | + attempts += 1 |
| 5474 | + continue |
| 5475 | + generate_wrap_file(projdir, port, user_key_path, host_public_key, source_hash) |
| 5476 | + try: |
| 5477 | + self.init(projdir) |
| 5478 | + self.build('doubler') |
| 5479 | + return |
| 5480 | + finally: |
| 5481 | + sshd.terminate() |
| 5482 | + sshd.wait() |
| 5483 | + raise self.fail(f'Failed to start sshd after {attempts} attempts.') |
0 commit comments