diff --git a/.cirrus.tasks.yml b/.cirrus.tasks.yml
index 92057006c9309..f818f4e77ee26 100644
--- a/.cirrus.tasks.yml
+++ b/.cirrus.tasks.yml
@@ -29,7 +29,7 @@ env:
   MTEST_ARGS: --print-errorlogs --no-rebuild -C build
   PGCTLTIMEOUT: 120 # avoids spurious failures during parallel tests
   TEMP_CONFIG: ${CIRRUS_WORKING_DIR}/src/tools/ci/pg_ci_base.conf
-  PG_TEST_EXTRA: kerberos ldap ssl libpq_encryption load_balance oauth
+  PG_TEST_EXTRA: kerberos ldap ssl libpq_encryption load_balance oauth python
 
 
 # What files to preserve in case tests fail
@@ -188,7 +188,7 @@ task:
     chown root:postgres /tmp/cores
     sysctl kern.corefile='/tmp/cores/%N.%P.core'
   setup_additional_packages_script: |
-    #pkg install -y ...
+    pkg install -y security/py-cryptography
 
   # NB: Intentionally build without -Dllvm. The freebsd image size is already
   # large enough to make VM startup slow, and even without llvm freebsd
@@ -239,7 +239,6 @@ task:
 
 task:
   depends_on: SanityCheck
-  trigger_type: manual
 
   env:
     # Below are experimentally derived to be a decent choice.
@@ -270,7 +269,7 @@ task:
   # -Duuid is not set for the NetBSD, see the comment below, above
   # configure_script, for more information.
   setup_additional_packages_script: |
-    #pkgin -y install ...
+    pkgin -y install py312-cryptography
   <<: *netbsd_task_template
 
 - name: OpenBSD - Meson
@@ -282,7 +281,7 @@ task:
     UUID: -Duuid=e2fs
     TCL: -Dtcl_version=tcl86
   setup_additional_packages_script: |
-    #pkg_add -I ...
+    pkg_add -I py3-cryptography
 
   # Always core dump to ${CORE_DUMP_DIR}
   set_core_dump_script: sysctl -w kern.nosuidcoredump=2
   <<: *openbsd_task_template
@@ -445,8 +444,9 @@ task:
   EOF
 
   setup_additional_packages_script: |
-    #apt-get update
-    #DEBIAN_FRONTEND=noninteractive apt-get -y install ...
+    apt-get update
+    DEBIAN_FRONTEND=noninteractive apt-get -y install \
+      python3-venv \
 
   matrix:
     # SPECIAL:
@@ -554,8 +554,11 @@ task:
   # can easily provide some here by running one of the sets of tests that
   # way. Newer versions of python insist on changing the LC_CTYPE away
   # from C, prevent that with PYTHONCOERCECLOCALE.
+  # XXX 32-bit Python tests are currently disabled, as the system's 64-bit
+  # Python modules can't link against libpq.
   test_world_32_script: |
     su postgres <<-EOF
+      export PG_TEST_EXTRA="${PG_TEST_EXTRA//python}"
       ulimit -c unlimited
       PYTHONCOERCECLOCALE=0 LANG=C meson test $MTEST_ARGS -C build-32 --num-processes ${TEST_JOBS}
   EOF
diff --git a/meson.build b/meson.build
index d142e3e408b38..b6c37e8e3aa96 100644
--- a/meson.build
+++ b/meson.build
@@ -943,10 +943,8 @@ if not libcurlopt.disabled()
   # libcurl and one of either epoll or kqueue.
   oauth_flow_supported = (
     libcurl.found()
-    and (cc.check_header('sys/event.h', required: false,
-          args: test_c_args, include_directories: postgres_inc)
-      or cc.check_header('sys/epoll.h', required: false,
-          args: test_c_args, include_directories: postgres_inc))
+    and (cc.has_header('sys/event.h', include_directories: postgres_inc)
+      or cc.has_header('sys/epoll.h', include_directories: postgres_inc))
   )
 
   if oauth_flow_supported
@@ -3615,6 +3613,9 @@ else
 endif
 
 testwrap = files('src/tools/testwrap')
+make_venv = files('src/tools/make_venv')
+
+checked_working_venv = false
 
 foreach test_dir : tests
   testwrap_base = [
@@ -3783,6 +3784,106 @@ foreach test_dir : tests
         )
       endforeach
       install_suites += test_group
+    elif kind == 'pytest'
+      venv_name = test_dir['name'] + '_venv'
+      venv_path = meson.build_root() / venv_name
+
+      # The Python tests require a working venv module. This is part of the
+      # standard library, but some platforms disable it until a separate package
+      # is installed. Those same platforms don't provide an easy way to check
+      # whether the venv command will work until the first time you try it, so
+      # we decide whether or not to enable these tests on the fly.
+      if not checked_working_venv
+        cmd = run_command(python, '-m', 'venv', venv_path, check: false)
+
+        have_working_venv = (cmd.returncode() == 0)
+        if not have_working_venv
+          warning('A working Python venv module is required to run Python tests.')
+        endif
+
+        checked_working_venv = true
+      endif
+
+      if not have_working_venv
+        continue
+      endif
+
+      # Make sure the temporary installation is in PATH (necessary both for
+      # --temp-instance and for any pip modules compiling against libpq, like
+      # psycopg2).
+      env = test_env
+      env.prepend('PATH', temp_install_bindir, test_dir['bd'])
+
+      foreach name, value : t.get('env', {})
+        env.set(name, value)
+      endforeach
+
+      reqs = files(t['requirements'])
+      test('install_' + venv_name,
+        python,
+        args: [ make_venv, '--requirements', reqs, venv_path ],
+        env: env,
+        priority: setup_tests_priority - 1, # must run after tmp_install
+        is_parallel: false,
+        suite: ['setup'],
+        timeout: 60, # 30s is too short for the cryptography package compile
+      )
+
+      test_group = test_dir['name']
+      test_output = test_result_dir / test_group / kind
+      test_kwargs = {
+        #'protocol': 'tap',
+        'suite': test_group,
+        'timeout': 1000,
+        'depends': test_deps,
+        'env': env,
+      } + t.get('test_kwargs', {})
+
+      if fs.is_dir(venv_path / 'Scripts')
+        # Windows virtualenv layout
+        pytest = venv_path / 'Scripts' / 'py.test'
+      else
+        pytest = venv_path / 'bin' / 'py.test'
+      endif
+
+      test_command = [
+        pytest,
+        # Avoid running these tests against an existing database.
+        '--temp-instance', test_output / 'data',
+
+        # FIXME pytest-tap's stream feature accidentally suppresses errors that
+        # are critical for debugging:
+        #     https://github.com/python-tap/pytest-tap/issues/30
+        # Don't use the meson TAP protocol for now...
+        #'--tap-stream',
+      ]
+
+      foreach pyt : t['tests']
+        # Similarly to TAP, strip ./ and .py to make the names prettier
+        pyt_p = pyt
+        if pyt_p.startswith('./')
+          pyt_p = pyt_p.split('./')[1]
+        endif
+        if pyt_p.endswith('.py')
+          pyt_p = fs.stem(pyt_p)
+        endif
+
+        testwrap_pytest = testwrap_base + [
+          '--testgroup', test_group,
+          '--testname', pyt_p,
+          '--skip-without-extra', 'python',
+        ]
+
+        test(test_group / pyt_p,
+          python,
+          kwargs: test_kwargs,
+          args: testwrap_pytest + [
+            '--', test_command,
+            test_dir['sd'] / pyt,
+          ],
+        )
+      endforeach
+      install_suites += test_group
     else
       error('unknown kind @0@ of test in @1@'.format(kind, test_dir['sd']))
     endif
diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c
index 27f7af7be0024..46c653efff55c 100644
--- a/src/backend/libpq/auth-oauth.c
+++ b/src/backend/libpq/auth-oauth.c
@@ -511,6 +511,7 @@ generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
 	initStringInfo(&buf);
 
 	/*
+	 * TODO
 	 * Escaping the string here is belt-and-suspenders defensive programming
 	 * since escapable characters aren't valid in either the issuer URI or the
 	 * scope list, but the HBA doesn't enforce that yet.
@@ -699,6 +700,7 @@ validate(Port *port, const char *auth)
 	/* Make sure the validator authenticated the user. */
 	if (ret->authn_id == NULL || ret->authn_id[0] == '\0')
 	{
+		/* TODO: test logdetail; reduce message duplication elsewhere */
 		ereport(LOG,
 				errmsg("OAuth bearer authentication failed for user \"%s\"",
 					   port->user_name),
diff --git a/src/interfaces/libpq-oauth/Makefile b/src/interfaces/libpq-oauth/Makefile
index 270fc0cf2d9d9..9da8e4b71435a 100644
--- a/src/interfaces/libpq-oauth/Makefile
+++ b/src/interfaces/libpq-oauth/Makefile
@@ -79,5 +79,19 @@ uninstall:
 	rm -f '$(DESTDIR)$(libdir)/$(stlib)'
 	rm -f '$(DESTDIR)$(libdir)/$(shlib)'
 
+.PHONY: all-tests
+all-tests: oauth_tests$(X)
+
+oauth_tests$(X): test-oauth-curl.o oauth-utils.o $(WIN32RES) | submake-libpgport submake-libpq
+	$(CC) $(CFLAGS) $^ $(LDFLAGS) $(LDFLAGS_EX) $(SHLIB_LINK) -o $@
+
+check: all-tests
+	$(prove_check)
+
+installcheck: all-tests
+	$(prove_installcheck)
+
 clean distclean: clean-lib
 	rm -f $(OBJS) $(OBJS_STATIC) $(OBJS_SHLIB)
+	rm -f test-oauth-curl.o oauth_tests$(X)
+	rm -rf tmp_check
diff --git a/src/interfaces/libpq-oauth/meson.build b/src/interfaces/libpq-oauth/meson.build
index df064c59a4070..505e1671b8637 100644
--- a/src/interfaces/libpq-oauth/meson.build
+++ b/src/interfaces/libpq-oauth/meson.build
@@ -47,3 +47,38 @@ libpq_oauth_so = shared_module(libpq_oauth_name,
   link_args: export_fmt.format(export_file.full_path()),
   kwargs: default_lib_args,
 )
+
+libpq_oauth_test_deps = []
+
+oauth_test_sources = files('test-oauth-curl.c') + libpq_oauth_so_sources
+
+if host_system == 'windows'
+  oauth_test_sources += rc_bin_gen.process(win32ver_rc, extra_args: [
+    '--NAME', 'oauth_tests',
+    '--FILEDESC', 'OAuth unit test program',])
+endif
+
+libpq_oauth_test_deps += executable('oauth_tests',
+  oauth_test_sources,
+  dependencies: [frontend_shlib_code, libpq, libpq_oauth_deps],
+  kwargs: default_bin_args + {
+    'c_args': default_bin_args.get('c_args', []) + libpq_oauth_so_c_args,
+    'c_pch': pch_postgres_fe_h,
+    'include_directories': [libpq_inc, postgres_inc],
+    'install': false,
+  }
+)
+
+testprep_targets += libpq_oauth_test_deps
+
+tests += {
+  'name': 'libpq-oauth',
+  'sd': meson.current_source_dir(),
+  'bd': meson.current_build_dir(),
+  'tap': {
+    'tests': [
+      't/001_oauth.pl',
+    ],
+    'deps': libpq_oauth_test_deps,
+  },
+}
diff --git a/src/interfaces/libpq-oauth/oauth-curl.c b/src/interfaces/libpq-oauth/oauth-curl.c
index dba9a684fa8a5..41b6bc584f366 100644
--- a/src/interfaces/libpq-oauth/oauth-curl.c
+++ b/src/interfaces/libpq-oauth/oauth-curl.c
@@ -278,6 +278,11 @@ struct async_ctx
 	bool		user_prompted;	/* have we already sent the authz prompt? */
 	bool		used_basic_auth;	/* did we send a client secret? */
 	bool		debugging;		/* can we give unsafe developer assistance? */
+	int			dbg_num_calls;	/* (debug mode) how many times were we called? */
+
+#if defined(HAVE_SYS_EVENT_H)
+	int			nevents;		/* how many events are we waiting on? */
+#endif
 };
 
 /*
@@ -1289,43 +1294,107 @@ register_socket(CURL *curl, curl_socket_t socket, int what, void *ctx,
 		return -1;
 	}
 
+	if (actx->debugging)
+		fprintf(stderr, "%s fd %d%s\n",
+				(op == EPOLL_CTL_DEL ? "Removed"
+				 : (op == EPOLL_CTL_ADD ? "Added" : "Updated")),
+				socket,
+				(what == CURL_POLL_REMOVE ? ""
+				 : (what == CURL_POLL_IN ? " (read)"
+					: (what == CURL_POLL_OUT ? " (write)"
+					   : " (read/write)"))));
+
 	return 0;
 #elif defined(HAVE_SYS_EVENT_H)
-	struct kevent ev[2] = {0};
+	struct kevent ev[2];
 	struct kevent ev_out[2];
 	struct timespec timeout = {0};
-	int			nev = 0;
+	int			nev;
 	int			res;
 
+	/*
+	 * First, any existing registrations for this socket need to be removed,
+	 * both to track the outstanding number of events, and to ensure that
+	 * we're not woken up for things that Curl no longer cares about.
+	 *
+	 * ENOENT is okay, but we have to track how many we get, so use
+	 * EV_RECEIPT.
+	 */
+	nev = 0;
+	EV_SET(&ev[nev], socket, EVFILT_READ, EV_DELETE | EV_RECEIPT, 0, 0, 0);
+	nev++;
+	EV_SET(&ev[nev], socket, EVFILT_WRITE, EV_DELETE | EV_RECEIPT, 0, 0, 0);
+	nev++;
+
+	Assert(nev <= lengthof(ev));
+	Assert(nev <= lengthof(ev_out));
+
+	res = kevent(actx->mux, ev, nev, ev_out, nev, &timeout);
+	if (res < 0)
+	{
+		actx_error(actx, "could not delete from kqueue: %m");
+		return -1;
+	}
+
+	/*
+	 * We can't use the simple errno version of kevent, because we need to
+	 * skip over ENOENT while still allowing a second change to be processed.
+	 * So we need a longer-form error checking loop.
+	 */
+	for (int i = 0; i < res; ++i)
+	{
+		/*
+		 * EV_RECEIPT should guarantee one EV_ERROR result for every change,
+		 * whether successful or not. Failed entries contain a non-zero errno
+		 * in the data field.
+		 */
+		Assert(ev_out[i].flags & EV_ERROR);
+
+		errno = ev_out[i].data;
+		if (!errno)
+		{
+			/* Successfully removed; update the event count. */
+			Assert(actx->nevents > 0);
+			actx->nevents--;
+		}
+		else if (errno != ENOENT)
+		{
+			actx_error(actx, "could not delete from kqueue: %m");
+			return -1;
+		}
+	}
+
+	/* If we're only removing registrations, we're done. */
+	if (what == CURL_POLL_REMOVE)
+		return 0;
+
+	/*
+	 * Now add the new filters. This is more straightforward than deletion.
+	 *
+	 * Combining this kevent() call with the one above seems like it should be
+	 * theoretically possible, but beware that not all BSDs keep the original
+	 * event flags when using EV_RECEIPT, so it's tricky to figure out which
+	 * operations succeeded. For now we keep the deletions and the additions
+	 * separate.
+	 */
+	nev = 0;
+
 	switch (what)
 	{
 		case CURL_POLL_IN:
-			EV_SET(&ev[nev], socket, EVFILT_READ, EV_ADD | EV_RECEIPT, 0, 0, 0);
+			EV_SET(&ev[nev], socket, EVFILT_READ, EV_ADD, 0, 0, 0);
 			nev++;
 			break;
 
 		case CURL_POLL_OUT:
-			EV_SET(&ev[nev], socket, EVFILT_WRITE, EV_ADD | EV_RECEIPT, 0, 0, 0);
+			EV_SET(&ev[nev], socket, EVFILT_WRITE, EV_ADD, 0, 0, 0);
 			nev++;
 			break;
 
 		case CURL_POLL_INOUT:
-			EV_SET(&ev[nev], socket, EVFILT_READ, EV_ADD | EV_RECEIPT, 0, 0, 0);
-			nev++;
-			EV_SET(&ev[nev], socket, EVFILT_WRITE, EV_ADD | EV_RECEIPT, 0, 0, 0);
-			nev++;
-			break;
-
-		case CURL_POLL_REMOVE:
-
-			/*
-			 * We don't know which of these is currently registered, perhaps
-			 * both, so we try to remove both. This means we need to tolerate
-			 * ENOENT below.
-			 */
-			EV_SET(&ev[nev], socket, EVFILT_READ, EV_DELETE | EV_RECEIPT, 0, 0, 0);
+			EV_SET(&ev[nev], socket, EVFILT_READ, EV_ADD, 0, 0, 0);
 			nev++;
-			EV_SET(&ev[nev], socket, EVFILT_WRITE, EV_DELETE | EV_RECEIPT, 0, 0, 0);
+			EV_SET(&ev[nev], socket, EVFILT_WRITE, EV_ADD, 0, 0, 0);
 			nev++;
 			break;
@@ -1334,41 +1403,26 @@ register_socket(CURL *curl, curl_socket_t socket, int what, void *ctx,
 			return -1;
 	}
 
-	res = kevent(actx->mux, ev, nev, ev_out, lengthof(ev_out), &timeout);
+	Assert(nev <= lengthof(ev));
+
+	res = kevent(actx->mux, ev, nev, NULL, 0, NULL);
 	if (res < 0)
 	{
 		actx_error(actx, "could not modify kqueue: %m");
 		return -1;
 	}
 
-	/*
-	 * We can't use the simple errno version of kevent, because we need to
-	 * skip over ENOENT while still allowing a second change to be processed.
-	 * So we need a longer-form error checking loop.
-	 */
-	for (int i = 0; i < res; ++i)
-	{
-		/*
-		 * EV_RECEIPT should guarantee one EV_ERROR result for every change,
-		 * whether successful or not. Failed entries contain a non-zero errno
-		 * in the data field.
-		 */
-		Assert(ev_out[i].flags & EV_ERROR);
+	/* Update the event count, and we're done. */
+	actx->nevents += nev;
 
-		errno = ev_out[i].data;
-		if (errno && errno != ENOENT)
-		{
-			switch (what)
-			{
-				case CURL_POLL_REMOVE:
-					actx_error(actx, "could not delete from kqueue: %m");
-					break;
-				default:
-					actx_error(actx, "could not add to kqueue: %m");
-			}
-			return -1;
-		}
-	}
+	if (actx->debugging)
+		fprintf(stderr, "%s fd %d%s\n",
+				(what == CURL_POLL_REMOVE ? "Removed" : "Updated"),
+				socket,
+				(what == CURL_POLL_REMOVE ? ""
+				 : (what == CURL_POLL_IN ? " (read)"
+					: (what == CURL_POLL_OUT ? " (write)"
+					   : " (read/write)"))));
 
 	return 0;
 #else
@@ -1420,6 +1474,11 @@ set_timer(struct async_ctx *actx, long timeout)
 		return false;
 	}
 
+	if (actx->debugging)
+		fprintf(stderr, "%s timer: %ld ms\n",
+				(timeout < 0 ? "Removed" : "Set"),
+				timeout);
+
 	return true;
 #elif defined(HAVE_SYS_EVENT_H)
 	struct kevent ev;
@@ -1441,7 +1500,8 @@ set_timer(struct async_ctx *actx, long timeout)
 	 * macOS.)
 	 *
 	 * If there was no previous timer set, the kevent calls will result in
-	 * ENOENT, which is fine.
+	 * ENOENT, which is fine. (We don't track actx->nevents for this case;
+	 * instead, drain_socket_events() just assumes a timer could be set.)
 	 */
 	EV_SET(&ev, 1, EVFILT_TIMER, EV_DELETE, 0, 0, 0);
 	if (kevent(actx->timerfd, &ev, 1, NULL, 0, NULL) < 0 && errno != ENOENT)
@@ -1459,7 +1519,12 @@ set_timer(struct async_ctx *actx, long timeout)
 
 	/* If we're not adding a timer, we're done. */
 	if (timeout < 0)
+	{
+		if (actx->debugging)
+			fprintf(stderr, "Removed timer: %ld ms\n", timeout);
+
 		return true;
+	}
 
 	EV_SET(&ev, 1, EVFILT_TIMER, (EV_ADD | EV_ONESHOT), 0, timeout, 0);
 	if (kevent(actx->timerfd, &ev, 1, NULL, 0, NULL) < 0)
@@ -1475,6 +1540,9 @@ set_timer(struct async_ctx *actx, long timeout)
 		return false;
 	}
 
+	if (actx->debugging)
+		fprintf(stderr, "Added timer: %ld ms\n", timeout);
+
 	return true;
 #else
 #error set_timer is not implemented on this platform
@@ -1483,49 +1551,94 @@ set_timer(struct async_ctx *actx, long timeout)
 
 /*
  * Returns 1 if the timeout in the multiplexer set has expired since the last
- * call to set_timer(), 0 if the timer is still running, or -1 (with an
- * actx_error() report) if the timer cannot be queried.
+ * call to set_timer(), 0 if the timer is either still running or disarmed, or
+ * -1 (with an actx_error() report) if the timer cannot be queried.
  */
 static int
 timer_expired(struct async_ctx *actx)
 {
-#if defined(HAVE_SYS_EPOLL_H)
-	struct itimerspec spec = {0};
+#if defined(HAVE_SYS_EPOLL_H) || defined(HAVE_SYS_EVENT_H)
	int			res;
 
-	if (timerfd_gettime(actx->timerfd, &spec) < 0)
+	/* Is the timer ready? */
+	res = PQsocketPoll(actx->timerfd, 1 /* forRead */ , 0, 0);
+	if (res < 0)
 	{
-		actx_error(actx, "getting timerfd value: %m");
+		actx_error(actx, "checking timer expiration: %m");
 		return -1;
 	}
 
+	if (actx->debugging)
+		fprintf(stderr, "timer has %sexpired\n", (res > 0 ? "" : "not "));
+
+	return (res > 0);
+#else
+#error timer_expired is not implemented on this platform
+#endif
+}
+
+static bool
+drain_socket_events(struct async_ctx *actx)
+{
+#if defined(HAVE_SYS_EPOLL_H)
+	/* The epoll implementation doesn't need to drain pending events. */
+	return true;
+#elif defined(HAVE_SYS_EVENT_H)
+	struct timespec timeout = {0};
+	struct kevent *drain;
+	int			drain_len;
+
 	/*
-	 * This implementation assumes we're using single-shot timers. If you
-	 * change to using intervals, you'll need to reimplement this function
-	 * too, possibly with the read() or select() interfaces for timerfd.
+	 * register_socket() keeps actx->nevents updated with the number of
+	 * outstanding event filters. We don't track the registration of the
+	 * timer; we just assume one could be registered here.
 	 */
-	Assert(spec.it_interval.tv_sec == 0
-		   && spec.it_interval.tv_nsec == 0);
+	drain_len = actx->nevents + 1;
 
-	/* If the remaining time to expiration is zero, we're done. */
-	return (spec.it_value.tv_sec == 0
-			&& spec.it_value.tv_nsec == 0);
-#elif defined(HAVE_SYS_EVENT_H)
-	int			res;
+	drain = malloc(sizeof(*drain) * drain_len);
+	if (!drain)
+	{
+		actx_error(actx, "out of memory");
+		return false;
+	}
 
-	/* Is the timer queue ready? */
-	res = PQsocketPoll(actx->timerfd, 1 /* forRead */ , 0, 0);
-	if (res < 0)
+	/* Discard all pending events. */
+	if (kevent(actx->mux, NULL, 0, drain, drain_len, &timeout) < 0)
 	{
-		actx_error(actx, "checking kqueue for timeout: %m");
-		return -1;
+		actx_error(actx, "could not drain kqueue: %m");
+		free(drain);
+		return false;
 	}
 
-	return (res > 0);
+	free(drain);
+	return true;
 #else
-#error timer_expired is not implemented on this platform
+#error drain_socket_events is not implemented on this platform
 #endif
 }
 
+static bool
+drain_timer_events(struct async_ctx *actx, bool *expired)
+{
+	int			res;
+
+	res = timer_expired(actx);
+	if (res < 0)
+		return false;
+
+	if (res > 0)
+	{
+		/* Timer is expired. Disable it to clear the signal in the mux. */
+		if (!set_timer(actx, -1))
+			return false;
+	}
+
+	if (expired)
+		*expired = (res > 0);
+
+	return true;
+}
+
 /*
  * Adds or removes timeouts from the multiplexer set, as directed by the
  * libcurl multi handle.
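[Review note: the calling pattern the two new drain helpers are designed around, as a minimal C sketch. quiesce_mux() is a hypothetical caller, not part of the patch; it only restates the ordering rule that the unit tests below pin down -- discard stale socket events first, then disarm the timer, so a pending timer event can't occupy a slot meant for a socket event.]

    /* Hypothetical caller of the new helpers (sketch only, not in the patch). */
    static bool
    quiesce_mux(struct async_ctx *actx, bool *timer_fired)
    {
        /* kqueue: reads back up to actx->nevents + 1 pending events */
        if (!drain_socket_events(actx))
            return false;

        /* disarms an expired timer via set_timer(actx, -1) */
        return drain_timer_events(actx, timer_fired);
    }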
@@ -1867,6 +1980,17 @@ drive_request(struct async_ctx *actx)
 	int			msgs_left;
 	bool		done;
 
+	if (actx->debugging)
+		fprintf(stderr, "In drive_request\n");
+
+	/*
+	if (timer_expired(actx))
+	{
+		if (!set_timer(actx, -1))
+			return PGRES_POLLING_FAILED;
+	}
+	*/
+
 	if (actx->running)
 	{
 		/*---
@@ -2932,6 +3056,8 @@ PostgresPollingStatusType
 pg_fe_run_oauth_flow(PGconn *conn)
 {
 	PostgresPollingStatusType result;
+	fe_oauth_state *state = conn_sasl_state(conn);
+	struct async_ctx *actx = state->async_ctx;
 #ifndef WIN32
 	sigset_t	osigset;
 	bool		sigpipe_pending;
@@ -2960,6 +3086,14 @@ pg_fe_run_oauth_flow(PGconn *conn)
 
 	result = pg_fe_run_oauth_flow_impl(conn);
 
+	if (actx && actx->debugging)
+	{
+		actx->dbg_num_calls++;
+		if (result == PGRES_POLLING_OK || result == PGRES_POLLING_FAILED)
+			fprintf(stderr, "[libpq] total number of polls: %d\n",
+					actx->dbg_num_calls);
+	}
+
 #ifndef WIN32
 	if (masked)
 	{
diff --git a/src/interfaces/libpq-oauth/t/001_oauth.pl b/src/interfaces/libpq-oauth/t/001_oauth.pl
new file mode 100644
index 0000000000000..e769856c2c9c9
--- /dev/null
+++ b/src/interfaces/libpq-oauth/t/001_oauth.pl
@@ -0,0 +1,24 @@
+# Copyright (c) 2025, PostgreSQL Global Development Group
+use strict;
+use warnings FATAL => 'all';
+
+use PostgreSQL::Test::Utils;
+use Test::More;
+
+# Defer entirely to the oauth_tests executable. stdout/err is routed through
+# Test::More so that our logging infrastructure can handle it correctly. Using
+# IPC::Run::new_chunker seems to help interleave the two streams a little better
+# than without.
+#
+# TODO: prove can also deal with native executables itself, which we could
+# probably make use of via PROVE_TESTS on the Makefile side. But the Meson setup
+# calls Perl directly, which would require more code to work around... and
+# there's still the matter of logging.
+my $builder = Test::More->builder;
+my $out = $builder->output;
+my $err = $builder->failure_output;
+
+IPC::Run::run ['oauth_tests'],
+  '>', IPC::Run::new_chunker, sub { print {$out} $_[0] },
+  '2>', IPC::Run::new_chunker, sub { print {$err} $_[0] }
+  or die "oauth_tests returned $?";
diff --git a/src/interfaces/libpq-oauth/test-oauth-curl.c b/src/interfaces/libpq-oauth/test-oauth-curl.c
new file mode 100644
index 0000000000000..3da87a89309ea
--- /dev/null
+++ b/src/interfaces/libpq-oauth/test-oauth-curl.c
@@ -0,0 +1,474 @@
+/*
+ * test-oauth-curl.c
+ *
+ * A unit test driver for libpq-oauth. This #includes oauth-curl.c, which lets
+ * the tests reference static functions and other internals.
+ *
+ * USE_ASSERT_CHECKING is required, to make it easy for tests to wrap
+ * must-succeed code as part of test setup.
+ *
+ * Copyright (c) 2025, PostgreSQL Global Development Group
+ */
+
+#include "oauth-curl.c"
+
+#include <fcntl.h>
+
+#ifdef USE_ASSERT_CHECKING
+
+/*
+ * TAP Helpers
+ */
+
+static int	num_tests = 0;
+
+/*
+ * Reports ok/not ok to the TAP stream on stdout.
+ */
+#define ok(OK, TEST) \
+	ok_impl(OK, TEST, #OK, __FILE__, __LINE__)
+
+static bool
+ok_impl(bool ok, const char *test, const char *teststr, const char *file, int line)
+{
+	printf("%sok %d - %s\n", ok ? "" : "not ", ++num_tests, test);
+
+	if (!ok)
+	{
+		printf("# at %s:%d:\n", file, line);
+		printf("# expression is false: %s\n", teststr);
+	}
+
+	return ok;
+}
+
+/*
+ * Like ok(this == that), but with more diagnostics on failure.
+ *
+ * Only works on ints, but luckily that's all we need here. Note that the much
+ * simpler-looking macro implementation
+ *
+ *     is_diag(ok(THIS == THAT, TEST), THIS, #THIS, THAT, #THAT)
+ *
+ * suffers from multiple evaluation of the macro arguments...
+ */
+#define is(THIS, THAT, TEST) \
+	do { \
+		int			this_ = (THIS), \
+					that_ = (THAT); \
+		is_diag( \
+			ok_impl(this_ == that_, TEST, #THIS " == " #THAT, __FILE__, __LINE__), \
+			this_, #THIS, that_, #THAT \
+		); \
+	} while (0)
+
+static bool
+is_diag(bool ok, int this, const char *thisstr, int that, const char *thatstr)
+{
+	if (!ok)
+		printf("# %s = %d; %s = %d\n", thisstr, this, thatstr, that);
+
+	return ok;
+}
+
+/*
+ * Utilities
+ */
+
+/*
+ * Creates a partially-initialized async_ctx for the purposes of testing. Free
+ * with free_test_actx().
+ */
+static struct async_ctx *
+init_test_actx(void)
+{
+	struct async_ctx *actx;
+
+	actx = calloc(1, sizeof(*actx));
+	Assert(actx);
+
+	actx->mux = PGINVALID_SOCKET;
+	actx->timerfd = -1;
+	actx->debugging = true;
+
+	initPQExpBuffer(&actx->errbuf);
+
+	Assert(setup_multiplexer(actx));
+
+	return actx;
+}
+
+static void
+free_test_actx(struct async_ctx *actx)
+{
+	termPQExpBuffer(&actx->errbuf);
+
+	if (actx->mux != PGINVALID_SOCKET)
+		close(actx->mux);
+	if (actx->timerfd >= 0)
+		close(actx->timerfd);
+
+	free(actx);
+}
+
+static char dummy_buf[4 * 1024];	/* for fill_pipe/drain_pipe */
+
+/*
+ * Writes to the write side of a pipe until it won't take any more data. Returns
+ * the amount written.
+ */
+static ssize_t
+fill_pipe(int fd)
+{
+	int			mode;
+	ssize_t		written = 0;
+
+	/* Don't block. */
+	Assert((mode = fcntl(fd, F_GETFL)) != -1);
+	Assert(fcntl(fd, F_SETFL, mode | O_NONBLOCK) == 0);
+
+	while (true)
+	{
+		ssize_t		w;
+
+		w = write(fd, dummy_buf, sizeof(dummy_buf));
+		if (w < 0)
+		{
+			if (errno != EAGAIN && errno != EWOULDBLOCK)
+			{
+				perror("write to pipe");
+				written = -1;
+			}
+			break;
+		}
+
+		written += w;
+	}
+
+	/* Reset the descriptor flags. */
+	Assert(fcntl(fd, F_SETFL, mode) == 0);
+
+	return written;
+}
+
+/*
+ * Drains the requested amount of data from the read side of a pipe.
+ */
+static bool
+drain_pipe(int fd, ssize_t n)
+{
+	Assert(n > 0);
+
+	while (n)
+	{
+		size_t		to_read = (n <= sizeof(dummy_buf)) ? n : sizeof(dummy_buf);
+		ssize_t		drained;
+
+		drained = read(fd, dummy_buf, to_read);
+		if (drained < 0)
+		{
+			perror("read from pipe");
+			return false;
+		}
+
+		n -= drained;
+	}
+
+	return true;
+}
+
+/*
+ * Tests whether the multiplexer is marked ready by the deadline. This is a
+ * macro so that file/line information makes sense during failures.
+ *
+ * NB: our current multiplexer implementations (epoll/kqueue) are *readable*
+ * when the underlying libcurl sockets are *writable*. This behavior is pinned
+ * here to record that expectation, but it's not a required part of the API. If
+ * you've added a new implementation that doesn't have that behavior, feel free
+ * to modify this test.
+ */
+#define mux_is_ready(MUX, DEADLINE, TEST) \
+	do { \
+		int			res_ = PQsocketPoll(MUX, 1, 0, DEADLINE); \
+		Assert(res_ != -1); \
+		ok(res_ > 0, "multiplexer is ready " TEST); \
+	} while (0)
+
+/*
+ * The opposite of mux_is_ready().
+ */
+#define mux_is_not_ready(MUX, TEST) \
+	do { \
+		int			res_ = PQsocketPoll(MUX, 1, 0, 0); \
+		Assert(res_ != -1); \
+		is(res_, 0, "multiplexer is not ready " TEST); \
+	} while (0)
+
+/*
+ * Test Suites
+ */
+
+/* Per-suite timeout. Set via the PG_TEST_TIMEOUT_DEFAULT envvar. */
+static pg_usec_time_t timeout_us = 180 * 1000 * 1000;
+
+static void
+test_set_timer(void)
+{
+	struct async_ctx *actx = init_test_actx();
+	const pg_usec_time_t deadline = PQgetCurrentTimeUSec() + timeout_us;
+
+	printf("# test_set_timer\n");
+
+	/* A zero-duration timer should result in a near-immediate ready signal. */
+	Assert(set_timer(actx, 0));
+	mux_is_ready(actx->mux, deadline, "when timer expires");
+	is(timer_expired(actx), 1, "timer_expired() returns 1 when timer expires");
+
+	/* Resetting the timer far in the future should unset the ready signal. */
+	Assert(set_timer(actx, INT_MAX));
+	mux_is_not_ready(actx->mux, "when timer is reset to the future");
+	is(timer_expired(actx), 0, "timer_expired() returns 0 with unexpired timer");
+
+	/* Setting another zero-duration timer should override the previous one. */
+	Assert(set_timer(actx, 0));
+	mux_is_ready(actx->mux, deadline, "when timer is re-expired");
+	is(timer_expired(actx), 1, "timer_expired() returns 1 when timer is re-expired");
+
+	/* And disabling that timer should once again unset the ready signal. */
+	Assert(set_timer(actx, -1));
+	mux_is_not_ready(actx->mux, "when timer is unset");
+	is(timer_expired(actx), 0, "timer_expired() returns 0 when timer is unset");
+
+	{
+		bool		expired;
+
+		/* Make sure drain_timer_events() functions correctly as well. */
+		Assert(set_timer(actx, 0));
+		mux_is_ready(actx->mux, deadline, "when timer is re-expired (drain_timer_events)");
+
+		Assert(drain_timer_events(actx, &expired));
+		mux_is_not_ready(actx->mux, "when timer is drained after expiring");
+		is(expired, 1, "drain_timer_events() reports expiration");
+		is(timer_expired(actx), 0, "timer_expired() returns 0 after timer is drained");
+
+		/* A second drain should do nothing. */
+		Assert(drain_timer_events(actx, &expired));
+		mux_is_not_ready(actx->mux, "when timer is drained a second time");
+		is(expired, 0, "drain_timer_events() reports no expiration");
+		is(timer_expired(actx), 0, "timer_expired() still returns 0");
+	}
+
+	free_test_actx(actx);
+}
+
+static void
+test_register_socket(void)
+{
+	struct async_ctx *actx = init_test_actx();
+	int			pipefd[2];
+	int			rfd,
+				wfd;
+	bool		bidirectional;
+
+	/* Create a local pipe for communication. */
+	Assert(pipe(pipefd) == 0);
+	rfd = pipefd[0];
+	wfd = pipefd[1];
+
+	/*
+	 * Some platforms (FreeBSD) implement bidirectional pipes, affecting the
+	 * behavior of some of these tests. Store that knowledge for later.
+	 */
+	bidirectional = PQsocketPoll(rfd /* read */ , 0, 1 /* write */ , 0) > 0;
+
+	/*
+	 * This suite runs twice -- once using CURL_POLL_IN/CURL_POLL_OUT for
+	 * read/write operations, respectively, and once using CURL_POLL_INOUT for
+	 * both sides.
+	 */
+	for (int inout = 0; inout < 2; inout++)
+	{
+		const int	in_event = inout ? CURL_POLL_INOUT : CURL_POLL_IN;
+		const int	out_event = inout ? CURL_POLL_INOUT : CURL_POLL_OUT;
+		const pg_usec_time_t deadline = PQgetCurrentTimeUSec() + timeout_us;
+		size_t		bidi_pipe_size;
+
+		printf("# test_register_socket %s\n", inout ? "(INOUT)" : "");
+
+		/*
+		 * At the start of the test, the read side should be blocked and the
+		 * write side should be open. (There's a mistake at the end of this
+		 * loop otherwise.)
+		 */
+		Assert(PQsocketPoll(rfd, 1, 0, 0) == 0);
+		Assert(PQsocketPoll(wfd, 0, 1, 0) > 0);
+
+		/*
+		 * For bidirectional systems, emulate unidirectional behavior here by
+		 * filling up the "read side" of the pipe.
+		 */
+		if (bidirectional)
+			Assert((bidi_pipe_size = fill_pipe(rfd)) > 0);
+
+		/* Listen on the read side. The multiplexer shouldn't be ready yet. */
+		Assert(register_socket(NULL, rfd, in_event, actx, NULL) == 0);
+		mux_is_not_ready(actx->mux, "when fd is not readable");
+
+		/* Writing to the pipe should result in a read-ready multiplexer. */
+		Assert(write(wfd, "x", 1) == 1);
+		mux_is_ready(actx->mux, deadline, "when fd is readable");
+
+		/*
+		 * Update the registration to wait on write events instead. The
+		 * multiplexer should be unset.
+		 */
+		Assert(register_socket(NULL, rfd, CURL_POLL_OUT, actx, NULL) == 0);
+		mux_is_not_ready(actx->mux, "when waiting for writes on readable fd");
+
+		/* Re-register for read events. */
+		Assert(register_socket(NULL, rfd, in_event, actx, NULL) == 0);
+		mux_is_ready(actx->mux, deadline, "when waiting for reads again");
+
+		/* Stop listening. The multiplexer should be unset. */
+		Assert(register_socket(NULL, rfd, CURL_POLL_REMOVE, actx, NULL) == 0);
+		mux_is_not_ready(actx->mux, "when readable fd is removed");
+
+		/* Listen again. */
+		Assert(register_socket(NULL, rfd, in_event, actx, NULL) == 0);
+		mux_is_ready(actx->mux, deadline, "when readable fd is re-added");
+
+		/*
+		 * Draining the pipe should unset the multiplexer again, once the old
+		 * event is drained.
+		 */
+		Assert(drain_pipe(rfd, 1));
+		Assert(drain_socket_events(actx));
+		mux_is_not_ready(actx->mux, "when fd is drained");
+
+		/* Undo any unidirectional emulation. */
+		if (bidirectional)
+			Assert(drain_pipe(wfd, bidi_pipe_size));
+
+		/* Listen on the write side. An empty buffer should be writable. */
+		Assert(register_socket(NULL, rfd, CURL_POLL_REMOVE, actx, NULL) == 0);
+		Assert(register_socket(NULL, wfd, out_event, actx, NULL) == 0);
+		mux_is_ready(actx->mux, deadline, "when fd is writable");
+
+		/* As above, wait on read events instead. */
+		Assert(register_socket(NULL, wfd, CURL_POLL_IN, actx, NULL) == 0);
+		mux_is_not_ready(actx->mux, "when waiting for reads on writable fd");
+
+		/* Re-register for write events. */
+		Assert(register_socket(NULL, wfd, out_event, actx, NULL) == 0);
+		mux_is_ready(actx->mux, deadline, "when waiting for writes again");
+
+		{
+			ssize_t		written;
+
+			/*
+			 * Fill the pipe. Once the old writable event is drained, the mux
+			 * should not be ready.
+			 */
+			Assert((written = fill_pipe(wfd)) > 0);
+			printf("# pipe buffer is full at %zd bytes\n", written);
+
+			Assert(drain_socket_events(actx));
+			mux_is_not_ready(actx->mux, "when fd buffer is full");
+
+			/* Drain the pipe again. */
+			Assert(drain_pipe(rfd, written));
+			mux_is_ready(actx->mux, deadline, "when fd buffer is drained");
+		}
+
+		/* Stop listening. */
+		Assert(register_socket(NULL, wfd, CURL_POLL_REMOVE, actx, NULL) == 0);
+		mux_is_not_ready(actx->mux, "when fd is removed");
+
+		/* Make sure an expired timer doesn't interfere with event draining. */
+		{
+			/* Make the rfd appear unidirectional if necessary. */
+			if (bidirectional)
+				Assert((bidi_pipe_size = fill_pipe(rfd)) > 0);
+
+			/* Set the timer and wait for it to expire. */
+			Assert(set_timer(actx, 0));
+			Assert(PQsocketPoll(actx->timerfd, 1, 0, deadline) > 0);
+			is(timer_expired(actx), 1, "timer is expired");
+
+			/* Register for read events and make the fd readable. */
+			Assert(register_socket(NULL, rfd, in_event, actx, NULL) == 0);
+			Assert(write(wfd, "x", 1) == 1);
+			mux_is_ready(actx->mux, deadline, "when fd is readable and timer expired");
+
+			/*
+			 * Draining the pipe should unset the multiplexer again, once the
+			 * old event is drained and the timer is reset.
+			 *
+			 * Order matters to avoid false negatives. First drain the socket,
+			 * then unset the timer. We're trying to catch the case where the
+			 * pending timer expiration event takes the place of one of the
+			 * socket events we're attempting to drain.
+			 */
+			Assert(drain_pipe(rfd, 1));
+			Assert(drain_socket_events(actx));
+			Assert(set_timer(actx, -1));
+
+			is(timer_expired(actx), 0, "timer is no longer expired");
+			mux_is_not_ready(actx->mux, "when fd is drained and timer reset");
+
+			/* Stop listening. */
+			Assert(register_socket(NULL, rfd, CURL_POLL_REMOVE, actx, NULL) == 0);
+
+			/* Undo any unidirectional emulation. */
+			if (bidirectional)
+				Assert(drain_pipe(wfd, bidi_pipe_size));
+		}
+	}
+
+	close(rfd);
+	close(wfd);
+	free_test_actx(actx);
+}
+
+int
+main(int argc, char *argv[])
+{
+	const char *timeout;
+
+	/* Grab the default timeout. */
+	timeout = getenv("PG_TEST_TIMEOUT_DEFAULT");
+	if (timeout)
+	{
+		int			timeout_s = atoi(timeout);
+
+		if (timeout_s > 0)
+			timeout_us = timeout_s * 1000 * 1000;
+	}
+
+	/*
+	 * Set up line buffering for our output, to let stderr interleave in the
+	 * log files.
+	 */
+	setvbuf(stdout, NULL, PG_IOLBF, 0);
+
+	test_set_timer();
+	test_register_socket();
+
+	printf("1..%d\n", num_tests);
+	return 0;
+}
+
+#else							/* !USE_ASSERT_CHECKING */
+
+/*
+ * Skip the test suite when we don't have assertions.
+ */
+int
+main(int argc, char *argv[])
+{
+	printf("1..0 # skip: cassert is not enabled\n");
+
+	return 0;
+}
+
+#endif							/* USE_ASSERT_CHECKING */
diff --git a/src/test/meson.build b/src/test/meson.build
index ccc31d6a86a1b..236057cd99e91 100644
--- a/src/test/meson.build
+++ b/src/test/meson.build
@@ -8,6 +8,7 @@ subdir('postmaster')
 subdir('recovery')
 subdir('subscription')
 subdir('modules')
+subdir('python')
 
 if ssl.found()
   subdir('ssl')
diff --git a/src/test/modules/oauth_validator/t/001_server.pl b/src/test/modules/oauth_validator/t/001_server.pl
index 41672ebd5c6dc..c0dafb8be7642 100644
--- a/src/test/modules/oauth_validator/t/001_server.pl
+++ b/src/test/modules/oauth_validator/t/001_server.pl
@@ -418,6 +418,35 @@ sub connstr
 	qr/failed to obtain access token: mutual TLS required for client \(invalid_client\)/
 );
 
+# Count the number of calls to the internal flow when multiple retries are
+# triggered. The exact number depends on many things -- the TCP stack, the
+# version of Curl in use, random chance -- but a ridiculously high number
+# suggests something is wrong with our ability to clear multiplexer events after
+# they're no longer applicable.
+my ($ret, $stdout, $stderr) = $node->psql(
+	'postgres',
+	"SELECT 'connected for call count'",
+	extra_params => ['-w'],
+	connstr => connstr(stage => 'token', retries => 2),
+	on_error_stop => 0);
+
+is($ret, 0, "call count connection succeeds");
+like(
+	$stderr,
+	qr@Visit https://example\.com/ and enter the code: postgresuser@,
+	"call count: stderr matches");
+
+my $count_pattern = qr/\[libpq\] total number of polls: (\d+)/;
+if (like($stderr, $count_pattern, "call count: count is printed"))
+{
+	# For reference, a typical flow with two retries might take between 5-15
+	# calls to the client implementation. And while this will probably continue
+	# to change across OSes and Curl updates, we're likely in trouble if we see
+	# hundreds or thousands of calls.
+	$stderr =~ $count_pattern;
+	cmp_ok($1, '<', 100, "call count is reasonably small");
+}
+
 # Stress test: make sure our builtin flow operates correctly even if the client
 # application isn't respecting PGRES_POLLING_READING/WRITING signals returned
 # from PQconnectPoll().
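[Review note: for contrast with the stress test referenced just above, this is roughly what a cooperative client loop looks like, using only public libpq calls. A sketch under assumptions: connect_cooperatively() and its connstr parameter are illustrative placeholders, not part of the patch.]

    #include "libpq-fe.h"

    /* Hypothetical well-behaved client: sleeps in the direction libpq asks for. */
    static int
    connect_cooperatively(const char *connstr)
    {
        PGconn     *conn = PQconnectStart(connstr);
        PostgresPollingStatusType status = PGRES_POLLING_WRITING;

        if (!conn)
            return 0;

        while (status != PGRES_POLLING_OK && status != PGRES_POLLING_FAILED)
        {
            int         forRead = (status == PGRES_POLLING_READING);

            /* end_time of -1 waits indefinitely for the requested direction */
            if (PQsocketPoll(PQsocket(conn), forRead, !forRead, -1) < 0)
                break;
            status = PQconnectPoll(conn);
        }

        return (status == PGRES_POLLING_OK);	/* caller PQfinish()es conn */
    }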
@@ -428,7 +457,7 @@ sub connstr
 	connstr(stage => 'all', retries => 1, interval => 1));
 note "running '" . join("' '", @cmd) . "'";
 
-my ($stdout, $stderr) = run_command(\@cmd);
+($stdout, $stderr) = run_command(\@cmd);
 
 like($stdout, qr/connection succeeded/, "stress-async: stdout matches");
 unlike(
diff --git a/src/test/python/.gitignore b/src/test/python/.gitignore
new file mode 100644
index 0000000000000..0e8f027b2ec2c
--- /dev/null
+++ b/src/test/python/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+/venv/
diff --git a/src/test/python/Makefile b/src/test/python/Makefile
new file mode 100644
index 0000000000000..b0695b6287e31
--- /dev/null
+++ b/src/test/python/Makefile
@@ -0,0 +1,38 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+# Only Python 3 is supported, but if it's named something different on your
+# system you can override it with the PYTHON3 variable.
+PYTHON3 := python3
+
+# All dependencies are placed into this directory. The default is .gitignored
+# for you, but you can override it if you'd like.
+VENV := ./venv
+
+override VBIN := $(VENV)/bin
+override PIP := $(VBIN)/pip
+override PYTEST := $(VBIN)/py.test
+override ISORT := $(VBIN)/isort
+override BLACK := $(VBIN)/black
+
+.PHONY: installcheck indent
+
+installcheck: $(PYTEST)
+	$(PYTEST) -v -rs
+
+indent: $(ISORT) $(BLACK)
+	$(ISORT) --profile black *.py client/*.py server/*.py
+	$(BLACK) *.py client/*.py server/*.py
+
+$(PYTEST) $(ISORT) $(BLACK) &: requirements.txt | $(PIP)
+	$(PIP) install --force-reinstall -r $<
+
+$(PIP):
+	$(PYTHON3) -m venv $(VENV)
+
+# A convenience recipe to rebuild psycopg2 against the local libpq.
+.PHONY: rebuild-psycopg2
+rebuild-psycopg2: | $(PIP)
+	$(PIP) install --force-reinstall --no-binary :all: $(shell grep psycopg2 requirements.txt)
diff --git a/src/test/python/README b/src/test/python/README
new file mode 100644
index 0000000000000..acf339a589915
--- /dev/null
+++ b/src/test/python/README
@@ -0,0 +1,66 @@
+A test suite for exercising both the libpq client and the server backend at the
+protocol level, based on pytest and Construct.
+
+WARNING! This suite takes superuser-level control of the cluster under test,
+writing to the server config, creating and destroying databases, etc. It also
+spins up various ephemeral TCP services. This is not safe for production servers
+and therefore must be explicitly opted into by setting PG_TEST_EXTRA=python in
+the environment.
+
+The test suite currently assumes that the standard PG* environment variables
+point to the database under test and are sufficient to log in a superuser on
+that system. In other words, a bare `psql` needs to Just Work before the test
+suite can do its thing. For a newly built dev cluster, typically all that's
+needed is a
+
+    export PGDATABASE=postgres
+
+but you can adjust as needed for your setup. See also 'Advanced Usage' below.
+
+## Requirements
+
+A supported version (3.6+) of Python.
+
+The first run of
+
+    make installcheck PG_TEST_EXTRA=python
+
+will install a local virtual environment and all needed dependencies. During
+development, if libpq changes incompatibly, you can issue
+
+    $ make rebuild-psycopg2
+
+to force a rebuild of the client library.
+
+## Hacking
+
+The code style is enforced by a _very_ opinionated autoformatter. Running the
+
+    make indent
+
+recipe will invoke it for you automatically. Don't fight the tool; part of the
+zen is in knowing that if the formatter makes your code ugly, there's probably a
+cleaner way to write your code.
+
+## Advanced Usage
+
+The Makefile is there for convenience, but you don't have to use it. Activate
+the virtualenv to be able to use pytest directly:
+
+    $ export PG_TEST_EXTRA=python
+    $ source venv/bin/activate
+    $ py.test -k oauth
+    ...
+    $ py.test ./server/test_server.py
+    ...
+    $ deactivate # puts the PATH et al back the way it was before
+
+To make quick smoke tests possible, slow tests have been marked explicitly. You
+can skip them by saying e.g.
+
+    $ py.test -m 'not slow'
+
+If you'd rather not test against an existing server, you can have the suite spin
+up a temporary one using whatever pg_ctl it finds in PATH:
+
+    $ py.test --temp-instance=./tmp_check
diff --git a/src/test/python/client/__init__.py b/src/test/python/client/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/src/test/python/client/conftest.py b/src/test/python/client/conftest.py
new file mode 100644
index 0000000000000..20e72a404aa5a
--- /dev/null
+++ b/src/test/python/client/conftest.py
@@ -0,0 +1,196 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import datetime
+import functools
+import ipaddress
+import os
+import socket
+import sys
+import threading
+
+import psycopg2
+import psycopg2.extras
+import pytest
+from cryptography import x509
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.x509.oid import NameOID
+
+import pq3
+
+BLOCKING_TIMEOUT = 2  # the number of seconds to wait for blocking calls
+
+
+@pytest.fixture
+def server_socket(unused_tcp_port_factory):
+    """
+    Returns a listening socket bound to an ephemeral port.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("127.0.0.1", unused_tcp_port_factory()))
+        s.listen(1)
+        s.settimeout(BLOCKING_TIMEOUT)
+        yield s
+
+
+class ClientHandshake(threading.Thread):
+    """
+    A thread that connects to a local Postgres server using psycopg2. Once the
+    opening handshake completes, the connection will be immediately closed.
+    """
+
+    def __init__(self, *, port, **kwargs):
+        super().__init__()
+
+        kwargs["port"] = port
+        self._kwargs = kwargs
+
+        self.exception = None
+
+    def run(self):
+        try:
+            conn = psycopg2.connect(host="127.0.0.1", **self._kwargs)
+            with contextlib.closing(conn):
+                self._pump_async(conn)
+        except Exception as e:
+            self.exception = e
+
+    def check_completed(self, timeout=BLOCKING_TIMEOUT):
+        """
+        Joins the client thread. Raises an exception if the thread could not be
+        joined, or if it threw an exception itself. (The exception will be
+        cleared, so future calls to check_completed will succeed.)
+        """
+        self.join(timeout)
+
+        if self.is_alive():
+            raise TimeoutError("client thread did not handshake within the timeout")
+        elif self.exception:
+            e = self.exception
+            self.exception = None
+            raise e
+
+    def _pump_async(self, conn):
+        """
+        Polls a psycopg2 connection until it's completed. (Synchronous
+        connections will work here too; they'll just immediately return OK.)
+        """
+        psycopg2.extras.wait_select(conn)
+
+
+@pytest.fixture
+def accept(server_socket):
+    """
+    Returns a factory function that, when called, returns a pair (sock, client)
+    where sock is a server socket that has accepted a connection from client,
+    and client is an instance of ClientHandshake. Clients will complete their
+    handshakes and cleanly disconnect.
+
+    The default connstring options may be extended or overridden by passing
+    arbitrary keyword arguments. Keep in mind that you generally should not
+    override the host or port, since they point to the local test server.
+
+    For situations where a client needs to connect more than once to complete a
+    handshake, the accept function may be called more than once. (The client
+    returned for subsequent calls will always be the same client that was
+    returned for the first call.)
+
+    Tests must either complete the handshake so that the client thread can be
+    automatically joined during teardown, or else call client.check_completed()
+    and manually handle any expected errors.
+    """
+    _, port = server_socket.getsockname()
+
+    client = None
+    default_opts = dict(
+        port=port,
+        user=pq3.pguser(),
+        sslmode="disable",
+    )
+
+    def factory(**kwargs):
+        nonlocal client
+
+        if client is None:
+            opts = dict(default_opts)
+            opts.update(kwargs)
+
+            # The server_socket is already listening, so the client thread can
+            # be safely started; it'll block on the connection until we accept.
+            client = ClientHandshake(**opts)
+            client.start()
+
+        sock, _ = server_socket.accept()
+        sock.settimeout(BLOCKING_TIMEOUT)
+        return sock, client
+
+    yield factory
+
+    if client is not None:
+        client.check_completed()
+
+
+@pytest.fixture
+def conn(accept):
+    """
+    Returns an accepted, wrapped pq3 connection to a psycopg2 client. The socket
+    will be closed when the test finishes, and the client will be checked for a
+    cleanly completed handshake.
+    """
+    sock, client = accept()
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            yield conn
+
+
+@pytest.fixture(scope="session")
+def certpair(tmp_path_factory):
+    """
+    Yields a (cert, key) pair of file paths that can be used by a TLS server.
+    The certificate is issued for "localhost" and its standard IPv4/6 addresses.
+    """
+
+    tmpdir = tmp_path_factory.mktemp("certs")
+    now = datetime.datetime.now(datetime.timezone.utc)
+
+    # https://cryptography.io/en/latest/x509/tutorial/#creating-a-self-signed-certificate
+    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+
+    subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "localhost")])
+    altNames = [
+        x509.DNSName("localhost"),
+        x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
+        x509.IPAddress(ipaddress.IPv6Address("::1")),
+    ]
+    cert = (
+        x509.CertificateBuilder()
+        .subject_name(subject)
+        .issuer_name(issuer)
+        .public_key(key.public_key())
+        .serial_number(x509.random_serial_number())
+        .not_valid_before(now)
+        .not_valid_after(now + datetime.timedelta(minutes=10))
+        .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
+        .add_extension(x509.SubjectAlternativeName(altNames), critical=False)
+    ).sign(key, hashes.SHA256())
+
+    # Writing the key with mode 0600 lets us use this from the server side, too.
+    keypath = str(tmpdir / "key.pem")
+    with open(keypath, "wb", opener=functools.partial(os.open, mode=0o600)) as f:
+        f.write(
+            key.private_bytes(
+                encoding=serialization.Encoding.PEM,
+                format=serialization.PrivateFormat.PKCS8,
+                encryption_algorithm=serialization.NoEncryption(),
+            )
+        )
+
+    certpath = str(tmpdir / "cert.pem")
+    with open(certpath, "wb") as f:
+        f.write(cert.public_bytes(serialization.Encoding.PEM))
+
+    return certpath, keypath
diff --git a/src/test/python/client/test_client.py b/src/test/python/client/test_client.py
new file mode 100644
index 0000000000000..f066f37a6a961
--- /dev/null
+++ b/src/test/python/client/test_client.py
@@ -0,0 +1,187 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import sys
+
+import psycopg2
+import pytest
+from cryptography.hazmat.primitives import hashes, hmac
+
+import pq3
+
+from .test_oauth import alt_patterns
+
+
+def finish_handshake(conn):
+    """
+    Sends the AuthenticationOK message and the standard opening salvo of server
+    messages, then asserts that the client immediately sends a Terminate message
+    to close the connection cleanly.
+    """
+    pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+    pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+    pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+    pq3.send(conn, pq3.types.BackendKeyData, pid=1234, key=0)
+    pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+    pkt = pq3.recv1(conn)
+    assert pkt.type == pq3.types.Terminate
+
+
+def test_handshake(conn):
+    startup = pq3.recv1(conn, cls=pq3.Startup)
+    assert startup.proto == pq3.protocol(3, 0)
+
+    finish_handshake(conn)
+
+
+def test_aborted_connection(accept):
+    """
+    Make sure the client correctly reports an early close during handshakes.
+    """
+    sock, client = accept()
+    sock.close()
+
+    expected = alt_patterns(
+        "server closed the connection unexpectedly",
+        # On some platforms, ECONNABORTED gets set instead.
+        "Software caused connection abort",
+    )
+    with pytest.raises(psycopg2.OperationalError, match=expected):
+        client.check_completed()
+
+
+#
+# SCRAM-SHA-256 (see RFC 5802: https://tools.ietf.org/html/rfc5802)
+#
+
+
+@pytest.fixture
+def password():
+    """
+    Returns a password for use by both client and server.
+    """
+    # TODO: parameterize this with passwords that require SASLprep.
+    return "secret"
+
+
+@pytest.fixture
+def pwconn(accept, password):
+    """
+    Like the conn fixture, but uses a password in the connection.
+    """
+    sock, client = accept(password=password)
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            yield conn
+
+
+def sha256(data):
+    """The H(str) function from Section 2.2."""
+    digest = hashes.Hash(hashes.SHA256())
+    digest.update(data)
+    return digest.finalize()
+
+
+def hmac_256(key, data):
+    """The HMAC(key, str) function from Section 2.2."""
+    h = hmac.HMAC(key, hashes.SHA256())
+    h.update(data)
+    return h.finalize()
+
+
+def xor(a, b):
+    """The XOR operation from Section 2.2."""
+    res = bytearray(a)
+    for i, byte in enumerate(b):
+        res[i] ^= byte
+    return bytes(res)
+
+
+def h_i(data, salt, i):
+    """The Hi(str, salt, i) function from Section 2.2."""
+    assert i > 0
+
+    acc = hmac_256(data, salt + b"\x00\x00\x00\x01")
+    last = acc
+    i -= 1
+
+    while i:
+        u = hmac_256(data, last)
+        acc = xor(acc, u)
+
+        last = u
+        i -= 1
+
+    return acc
+
+
+def test_scram(pwconn, password):
+    startup = pq3.recv1(pwconn, cls=pq3.Startup)
+    assert startup.proto == pq3.protocol(3, 0)
+
+    pq3.send(
+        pwconn,
+        pq3.types.AuthnRequest,
+        type=pq3.authn.SASL,
+        body=[b"SCRAM-SHA-256", b""],
+    )
+
+    # Get the client-first-message.
+    pkt = pq3.recv1(pwconn)
+    assert pkt.type == pq3.types.PasswordMessage
+
+    initial = pq3.SASLInitialResponse.parse(pkt.payload)
+    assert initial.name == b"SCRAM-SHA-256"
+
+    c_bind, authzid, c_name, c_nonce = initial.data.split(b",")
+    assert c_bind == b"n"  # no channel bindings on a plaintext connection
+    assert authzid == b""  # we don't support authzid currently
+    assert c_name == b"n="  # libpq doesn't honor the GS2 username
+    assert c_nonce.startswith(b"r=")
+
+    # Send the server-first-message.
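+    # Per RFC 5802, the server-first-message has the form
+    #     r=<client nonce + server nonce>,s=<base64 salt>,i=<iteration count>
+    # The values below are deliberately weak (tiny salt, two iterations): this
+    # test exercises the wire protocol, not the strength of the KDF.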
+ salt = b"12345" + iterations = 2 + + s_nonce = c_nonce + b"somenonce" + s_salt = b"s=" + base64.b64encode(salt) + s_iterations = b"i=%d" % iterations + + msg = b",".join([s_nonce, s_salt, s_iterations]) + pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=msg) + + # Get the client-final-message. + pkt = pq3.recv1(pwconn) + assert pkt.type == pq3.types.PasswordMessage + + c_bind_final, c_nonce_final, c_proof = pkt.payload.split(b",") + assert c_bind_final == b"c=" + base64.b64encode(c_bind + b"," + authzid + b",") + assert c_nonce_final == s_nonce + + # Calculate what the client proof should be. + salted_password = h_i(password.encode("ascii"), salt, iterations) + client_key = hmac_256(salted_password, b"Client Key") + stored_key = sha256(client_key) + + auth_message = b",".join( + [c_name, c_nonce, s_nonce, s_salt, s_iterations, c_bind_final, c_nonce_final] + ) + client_signature = hmac_256(stored_key, auth_message) + client_proof = xor(client_key, client_signature) + + expected = b"p=" + base64.b64encode(client_proof) + assert c_proof == expected + + # Send the correct server signature. + server_key = hmac_256(salted_password, b"Server Key") + server_signature = hmac_256(server_key, auth_message) + + s_verify = b"v=" + base64.b64encode(server_signature) + pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal, body=s_verify) + + # Done! + finish_handshake(pwconn) diff --git a/src/test/python/client/test_oauth.py b/src/test/python/client/test_oauth.py new file mode 100644 index 0000000000000..0635a84749e86 --- /dev/null +++ b/src/test/python/client/test_oauth.py @@ -0,0 +1,2737 @@ +# +# Copyright 2021 VMware, Inc. +# Portions Copyright 2023 Timescale, Inc. +# Portions Copyright 2024 PostgreSQL Global Development Group +# SPDX-License-Identifier: PostgreSQL +# + +import base64 +import collections +import contextlib +import ctypes +import http.server +import json +import logging +import os +import platform +import secrets +import socket +import ssl +import sys +import threading +import time +import traceback +import types +import urllib.parse +from numbers import Number + +import psycopg2 +import pytest + +import pq3 + +from .conftest import BLOCKING_TIMEOUT + +# The client tests need libpq to have been compiled with OAuth support; skip +# them otherwise. +pytestmark = pytest.mark.skipif( + os.getenv("with_libcurl") != "yes", + reason="OAuth client tests require --with-libcurl support", +) + +if platform.system() == "Darwin": + libpq = ctypes.cdll.LoadLibrary("libpq.5.dylib") +elif platform.system() == "Windows": + pass # TODO +else: + libpq = ctypes.cdll.LoadLibrary("libpq.so.5") + + +def finish_handshake(conn): + """ + Sends the AuthenticationOK message and the standard opening salvo of server + messages, then asserts that the client immediately sends a Terminate message + to close the connection cleanly. + """ + pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK) + pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8") + pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY") + pq3.send(conn, pq3.types.BackendKeyData, pid=1234, key=0) + pq3.send(conn, pq3.types.ReadyForQuery, status=b"I") + + pkt = pq3.recv1(conn) + assert pkt.type == pq3.types.Terminate + + +# +# OAUTHBEARER (see RFC 7628: https://tools.ietf.org/html/rfc7628) +# + + +def start_oauth_handshake(conn): + """ + Negotiates an OAUTHBEARER SASL challenge. Returns the client's initial + response data. 
+ """ + startup = pq3.recv1(conn, cls=pq3.Startup) + assert startup.proto == pq3.protocol(3, 0) + + pq3.send( + conn, pq3.types.AuthnRequest, type=pq3.authn.SASL, body=[b"OAUTHBEARER", b""] + ) + + pkt = pq3.recv1(conn) + assert pkt.type == pq3.types.PasswordMessage + + initial = pq3.SASLInitialResponse.parse(pkt.payload) + assert initial.name == b"OAUTHBEARER" + + return initial.data + + +def get_auth_value(initial): + """ + Finds the auth value (e.g. "Bearer somedata..." in the client's initial SASL + response. + """ + kvpairs = initial.split(b"\x01") + assert kvpairs[0] == b"n,," # no channel binding or authzid + assert kvpairs[2] == b"" # ends with an empty kvpair + assert kvpairs[3] == b"" # ...and there's nothing after it + assert len(kvpairs) == 4 + + key, value = kvpairs[1].split(b"=", 2) + assert key == b"auth" + + return value + + +def fail_oauth_handshake(conn, sasl_resp, *, errmsg="doesn't matter"): + """ + Sends a failure response via the OAUTHBEARER mechanism, consumes the + client's dummy response, and issues a FATAL error to end the exchange. + + sasl_resp is a dictionary which will be serialized as the OAUTHBEARER JSON + response. If provided, errmsg is used in the FATAL ErrorResponse. + """ + resp = json.dumps(sasl_resp) + pq3.send( + conn, + pq3.types.AuthnRequest, + type=pq3.authn.SASLContinue, + body=resp.encode("utf-8"), + ) + + # Per RFC, the client is required to send a dummy ^A response. + pkt = pq3.recv1(conn) + assert pkt.type == pq3.types.PasswordMessage + assert pkt.payload == b"\x01" + + # Now fail the SASL exchange. + pq3.send( + conn, + pq3.types.ErrorResponse, + fields=[ + b"SFATAL", + b"C28000", + b"M" + errmsg.encode("utf-8"), + b"", + ], + ) + + +def handle_discovery_connection(sock, discovery=None, *, response=None): + """ + Helper for all tests that expect an initial discovery connection from the + client. The provided discovery URI will be used in a standard error response + from the server (or response may be set, to provide a custom dictionary), + and the SASL exchange will be failed. + + By default, the client is expected to complete the entire handshake. Set + finish to False if the client should immediately disconnect when it receives + the error response. + """ + if response is None: + response = {"status": "invalid_token"} + if discovery is not None: + response["openid-configuration"] = discovery + + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + # Initiate a handshake. + initial = start_oauth_handshake(conn) + + # For discovery, the client should send an empty auth header. See RFC + # 7628, Sec. 4.3. + auth = get_auth_value(initial) + assert auth == b"" + + # The discovery handshake is doomed to fail. + fail_oauth_handshake(conn, response) + + +class RawResponse(str): + """ + Returned by registered endpoint callbacks to take full control of the + response. Usually, return values are converted to JSON; a RawResponse body + will be passed to the client as-is, allowing endpoint implementations to + issue invalid JSON. + """ + + pass + + +class RawBytes(bytes): + """ + Like RawResponse, but bypasses the UTF-8 encoding step as well, allowing + implementations to issue invalid encodings. + """ + + pass + + +class OpenIDProvider(threading.Thread): + """ + A thread that runs a mock OpenID provider server on an SSL-enabled socket. 
+ """ + + def __init__(self, ssl_socket): + super().__init__() + + self.exception = None + + port = ssl_socket.getsockname()[1] + oauth = self._OAuthState() + + if socket.has_dualstack_ipv6(): + oauth.host = f"localhost:{port}" + oauth.issuer = f"https://localhost:{port}" + else: + oauth.host = f"127.0.0.1:{port}" + oauth.issuer = f"https://127.0.0.1:{port}" + + # The following endpoints are required to be advertised by providers, + # even though our chosen client implementation does not actually make + # use of them. + oauth.register_endpoint( + "authorization_endpoint", "POST", "/authorize", self._authorization_handler + ) + oauth.register_endpoint("jwks_uri", "GET", "/keys", self._jwks_handler) + + self.server = self._HTTPSServer(ssl_socket, self._Handler) + self.server.oauth = oauth + + def run(self): + try: + # XXX socketserver.serve_forever() has a serious architectural + # issue: its select loop wakes up every `poll_interval` seconds to + # see if the server is shutting down. The default, 500 ms, only lets + # us run two tests every second. But the faster we go, the more CPU + # we burn unnecessarily... + self.server.serve_forever(poll_interval=0.01) + except Exception as e: + self.exception = e + + def stop(self, timeout=BLOCKING_TIMEOUT): + """ + Shuts down the server and joins its thread. Raises an exception if the + thread could not be joined, or if it threw an exception itself. Must + only be called once, after start(). + """ + self.server.shutdown() + self.join(timeout) + + if self.is_alive(): + raise TimeoutError("client thread did not handshake within the timeout") + elif self.exception: + e = self.exception + raise e + + class _OAuthState(object): + def __init__(self): + self.endpoint_paths = {} + self._endpoints = {} + + # Provide a standard discovery document by default; tests can + # override it. + self.register_endpoint( + None, + "GET", + "/.well-known/openid-configuration", + self._default_discovery_handler, + ) + + # Default content type unless overridden. + self.content_type = "application/json" + + @property + def discovery_uri(self): + return f"{self.issuer}/.well-known/openid-configuration" + + def register_endpoint(self, name, method, path, func): + if method not in self._endpoints: + self._endpoints[method] = {} + + self._endpoints[method][path] = func + + if name is not None: + self.endpoint_paths[name] = path + + def endpoint(self, method, path): + if method not in self._endpoints: + return None + + return self._endpoints[method].get(path) + + def _default_discovery_handler(self, headers, params): + doc = { + "issuer": self.issuer, + "response_types_supported": ["token"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + "grant_types_supported": [ + "authorization_code", + "urn:ietf:params:oauth:grant-type:device_code", + ], + } + + for name, path in self.endpoint_paths.items(): + doc[name] = self.issuer + path + + return 200, doc + + class _HTTPSServer(http.server.HTTPServer): + def __init__(self, ssl_socket, handler_cls): + # Attach the SSL socket to the server. We don't bind/activate since + # the socket is already listening. + super().__init__(None, handler_cls, bind_and_activate=False) + self.socket = ssl_socket + self.server_address = self.socket.getsockname() + + def shutdown_request(self, request): + # Cleanly unwrap the SSL socket before shutting down the connection; + # otherwise careful clients will complain about truncation. 
+ try: + request = request.unwrap() + except (ssl.SSLEOFError, ConnectionResetError, BrokenPipeError): + # The client already closed (or aborted) the connection without + # a clean shutdown. This is seen on some platforms during tests + # that break the HTTP protocol. Just return and have the server + # close the socket. + return + except ssl.SSLError as err: + # FIXME OpenSSL 3.4 introduced an incompatibility with Python's + # TLS error handling, resulting in a bogus "[SYS] unknown error" + # on some platforms. Hopefully this is fixed in 2025's set of + # maintenance releases and this case can be removed. + # + # https://github.com/python/cpython/issues/127257 + # + if "[SYS] unknown error" in str(err): + return + raise + + super().shutdown_request(request) + + def handle_error(self, request, addr): + self.shutdown_request(request) + raise + + @staticmethod + def _jwks_handler(headers, params): + return 200, {"keys": []} + + @staticmethod + def _authorization_handler(headers, params): + # We don't actually want this to be called during these tests -- we + # should be using the device authorization endpoint instead. + assert ( + False + ), "authorization handler called instead of device authorization handler" + + class _Handler(http.server.BaseHTTPRequestHandler): + timeout = BLOCKING_TIMEOUT + + def _handle(self, *, params=None, handler=None): + oauth = self.server.oauth + assert self.headers["Host"] == oauth.host + + # XXX: BaseHTTPRequestHandler collapses leading slashes in the path + # to work around an open redirection vuln (gh-87389) in + # SimpleHTTPServer. But we're not using SimpleHTTPServer, and we + # want to test repeating leading slashes, so that's not very + # helpful. Put them back. + orig_path = self.raw_requestline.split()[1] + orig_path = str(orig_path, "iso-8859-1") + assert orig_path.endswith(self.path) # sanity check + self.path = orig_path + + if handler is None: + handler = oauth.endpoint(self.command, self.path) + assert ( + handler is not None + ), f"no registered endpoint for {self.command} {self.path}" + + result = handler(self.headers, params) + + if len(result) == 2: + headers = {"Content-Type": oauth.content_type} + code, resp = result + else: + code, headers, resp = result + + self.send_response(code) + for h, v in headers.items(): + self.send_header(h, v) + self.end_headers() + + if resp is not None: + if not isinstance(resp, RawBytes): + if not isinstance(resp, RawResponse): + resp = json.dumps(resp) + resp = resp.encode("utf-8") + self.wfile.write(resp) + + self.close_connection = True + + def do_GET(self): + self._handle() + + def _request_body(self): + length = self.headers["Content-Length"] + + # Handle only an explicit content-length. + assert length is not None + length = int(length) + + return self.rfile.read(length).decode("utf-8") + + def do_POST(self): + assert self.headers["Content-Type"] == "application/x-www-form-urlencoded" + + body = self._request_body() + if body: + # parse_qs() is understandably fairly lax when it comes to + # acceptable characters, but we're stricter. Spaces must be + # encoded, and they must use the '+' encoding rather than "%20". 
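+                # For example, a scope of "openid email" must arrive here as
+                # "scope=openid+email", never "scope=openid%20email".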
+ assert " " not in body + assert "%20" not in body + + params = urllib.parse.parse_qs( + body, + keep_blank_values=True, + strict_parsing=True, + encoding="utf-8", + errors="strict", + ) + else: + params = {} + + self._handle(params=params) + + +@pytest.fixture(autouse=True) +def enable_client_oauth_debugging(monkeypatch): + """ + HTTP providers aren't allowed by default; enable them via envvar. + """ + monkeypatch.setenv("PGOAUTHDEBUG", "UNSAFE") + + +@pytest.fixture(autouse=True) +def trust_certpair_in_client(monkeypatch, certpair): + """ + Set a trusted CA file for OAuth client connections. + """ + monkeypatch.setenv("PGOAUTHCAFILE", certpair[0]) + + +@pytest.fixture(scope="session") +def ssl_socket(certpair): + """ + A listening server-side socket for SSL connections, using the certpair + fixture. + """ + # Try to listen on both IPv4 and v6, if possible, for extra coverage of Curl + # corner cases compared to the standard test suite. Otherwise just use IPv4. + if socket.has_dualstack_ipv6(): + sock = socket.create_server( + ("", 0), family=socket.AF_INET6, dualstack_ipv6=True + ) + else: + sock = socket.create_server(("", 0)) + + # The TLS connections we're making are incredibly sensitive to delayed ACKs + # from the client. (Without TCP_NODELAY, test performance degrades 4-5x.) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + with contextlib.closing(sock): + # Wrap the server socket for TLS. + ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ctx.load_cert_chain(*certpair) + + yield ctx.wrap_socket(sock, server_side=True) + + +@pytest.fixture +def openid_provider(ssl_socket): + """ + A fixture that returns the OAuth state of a running OpenID provider server. The + server will be stopped when the fixture is torn down. + """ + thread = OpenIDProvider(ssl_socket) + thread.start() + + try: + yield thread.server.oauth + finally: + thread.stop() + + +# +# PQAuthDataHook implementation, matching libpq.h +# + + +PQAUTHDATA_PROMPT_OAUTH_DEVICE = 0 +PQAUTHDATA_OAUTH_BEARER_TOKEN = 1 + +PGRES_POLLING_FAILED = 0 +PGRES_POLLING_READING = 1 +PGRES_POLLING_WRITING = 2 +PGRES_POLLING_OK = 3 + + +class PGPromptOAuthDevice(ctypes.Structure): + _fields_ = [ + ("verification_uri", ctypes.c_char_p), + ("user_code", ctypes.c_char_p), + ("verification_uri_complete", ctypes.c_char_p), + ("expires_in", ctypes.c_int), + ] + + +class PGOAuthBearerRequest(ctypes.Structure): + pass + + +PGOAuthBearerRequest._fields_ = [ + ("openid_configuration", ctypes.c_char_p), + ("scope", ctypes.c_char_p), + ( + "async_", + ctypes.CFUNCTYPE( + ctypes.c_int, + ctypes.c_void_p, + ctypes.POINTER(PGOAuthBearerRequest), + ctypes.POINTER(ctypes.c_int), + ), + ), + ( + "cleanup", + ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.POINTER(PGOAuthBearerRequest)), + ), + ("token", ctypes.c_char_p), + ("user", ctypes.c_void_p), +] + + +@pytest.fixture +def auth_data_cb(): + """ + Tracks calls to the libpq authdata hook. The yielded object contains a calls + member that records the data sent to the hook. If a test needs to perform + custom actions during a call, it can set the yielded object's impl callback; + beware that the callback takes place on a different thread. + + This is done differently from the other callback implementations on purpose. + For the others, we can declare test-specific callbacks and have them perform + direct assertions on the data they receive. But that won't work for a C + callback, because there's no way for us to bubble up the assertion through + libpq. 
Instead, this mock-style approach is taken, where we just record the + calls and let the test examine them later. + """ + + class _Call: + pass + + class _cb(object): + def __init__(self): + self.calls = [] + + cb = _cb() + cb.impl = None + + # The callback will occur on a different thread, so protect the cb object. + cb_lock = threading.Lock() + + @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_byte, ctypes.c_void_p, ctypes.c_void_p) + def auth_data_cb(typ, pgconn, data): + handle_by_default = 0 # does an implementation have to be provided? + + if typ == PQAUTHDATA_PROMPT_OAUTH_DEVICE: + cls = PGPromptOAuthDevice + handle_by_default = 1 + elif typ == PQAUTHDATA_OAUTH_BEARER_TOKEN: + cls = PGOAuthBearerRequest + else: + return 0 + + call = _Call() + call.type = typ + + # The lifetime of the underlying data being pointed to doesn't + # necessarily match the lifetime of the Python object, so we can't + # reference a Structure's fields after returning. Explicitly copy the + # contents over, field by field. + data = ctypes.cast(data, ctypes.POINTER(cls)) + for name, _ in cls._fields_: + setattr(call, name, getattr(data.contents, name)) + + with cb_lock: + cb.calls.append(call) + + if cb.impl: + # Pass control back to the test. + try: + return cb.impl(typ, pgconn, data.contents) + except Exception: + # This can't escape into the C stack, but we can fail the flow + # and hope the traceback gives us enough detail. + logging.error( + "Exception during authdata hook callback:\n" + + traceback.format_exc() + ) + return -1 + + return handle_by_default + + libpq.PQsetAuthDataHook(auth_data_cb) + try: + yield cb + finally: + # The callback is about to go out of scope, so make sure libpq is + # disconnected from it. (We wouldn't want to accidentally influence + # later tests anyway.) 
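+        # (Passing None is assumed to reset libpq to its default handling, as
+        # if no hook had ever been installed.)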
+ libpq.PQsetAuthDataHook(None) + + +@pytest.mark.parametrize( + "success, abnormal_failure", + [ + pytest.param(True, False, id="success"), + pytest.param(False, False, id="normal failure"), + pytest.param(False, True, id="abnormal failure"), + ], +) +@pytest.mark.parametrize("secret", [None, "", "hunter2"]) +@pytest.mark.parametrize("scope", [None, "", "openid email"]) +@pytest.mark.parametrize("retries", [0, 1]) +@pytest.mark.parametrize( + "content_type", + [ + pytest.param("application/json", id="standard"), + pytest.param("application/json;charset=utf-8", id="charset"), + pytest.param("application/json \t;\t charset=utf-8", id="charset (whitespace)"), + ], +) +@pytest.mark.parametrize("uri_spelling", ["verification_url", "verification_uri"]) +@pytest.mark.parametrize( + "asynchronous", + [ + pytest.param(False, id="synchronous"), + pytest.param(True, id="asynchronous"), + ], +) +def test_oauth_with_explicit_discovery_uri( + accept, + openid_provider, + asynchronous, + uri_spelling, + content_type, + retries, + scope, + secret, + auth_data_cb, + success, + abnormal_failure, +): + client_id = secrets.token_hex() + openid_provider.content_type = content_type + + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id=client_id, + oauth_client_secret=secret, + oauth_scope=scope, + async_=asynchronous, + ) + + device_code = secrets.token_hex() + user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}" + verification_url = "https://example.com/device" + + access_token = secrets.token_urlsafe() + + def check_client_authn(headers, params): + if secret is None: + assert "Authorization" not in headers + assert params["client_id"] == [client_id] + return + + # Require the client to use Basic authn; request-body credentials are + # NOT RECOMMENDED (RFC 6749, Sec. 2.3.1). + assert "Authorization" in headers + assert "client_id" not in params + + method, creds = headers["Authorization"].split() + assert method == "Basic" + + expected = f"{client_id}:{secret}" + assert base64.b64decode(creds) == expected.encode("ascii") + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + check_client_authn(headers, params) + + if scope: + assert params["scope"] == [scope] + else: + assert "scope" not in params + + resp = { + "device_code": device_code, + "user_code": user_code, + "interval": 0, + uri_spelling: verification_url, + "expires_in": 5, + } + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + attempts = 0 + retry_lock = threading.Lock() + + def token_endpoint(headers, params): + check_client_authn(headers, params) + + assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"] + assert params["device_code"] == [device_code] + + now = time.monotonic() + + with retry_lock: + nonlocal attempts + + # If the test wants to force the client to retry, return an + # authorization_pending response and decrement the retry count. + if attempts < retries: + attempts += 1 + return 400, {"error": "authorization_pending"} + + # Successfully finish the request by sending the access bearer token. 
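+            # (access_token and token_type are the two members required by
+            # RFC 6749, Sec. 5.1.)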
+ resp = { + "access_token": access_token, + "token_type": "bearer", + } + + return 200, resp + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + # First connection is a discovery request, which should result in the above + # endpoints being called. + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + # Client should reconnect. + sock, _ = accept() + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + initial = start_oauth_handshake(conn) + + # Validate and accept the token. + auth = get_auth_value(initial) + assert auth == f"Bearer {access_token}".encode("ascii") + + if success: + finish_handshake(conn) + + elif abnormal_failure: + # Send an empty error response, which should result in a + # mechanism-level failure in the client. This test ensures that + # the client doesn't try a third connection for this case. + expected_error = "server sent error response without a status" + fail_oauth_handshake(conn, {}) + + else: + # Simulate token validation failure. + resp = { + "status": "invalid_token", + "openid-configuration": openid_provider.discovery_uri, + } + expected_error = "test token validation failure" + fail_oauth_handshake(conn, resp, errmsg=expected_error) + + if retries: + # Finally, make sure that the client prompted the user once with the + # expected authorization URL and user code. + assert len(auth_data_cb.calls) == 2 + + # First call should have been for a custom flow, which we ignored. + assert auth_data_cb.calls[0].type == PQAUTHDATA_OAUTH_BEARER_TOKEN + + # Second call is for our user prompt. + call = auth_data_cb.calls[1] + assert call.type == PQAUTHDATA_PROMPT_OAUTH_DEVICE + assert call.verification_uri.decode() == verification_url + assert call.user_code.decode() == user_code + assert call.verification_uri_complete is None + assert call.expires_in == 5 + + if not success: + # The client should not try to connect again. + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +@pytest.mark.parametrize( + "server_discovery", + [ + pytest.param(True, id="server discovery"), + pytest.param(False, id="direct discovery"), + ], +) +@pytest.mark.parametrize( + "issuer, path", + [ + pytest.param( + "{issuer}", + "/.well-known/oauth-authorization-server", + id="oauth", + ), + pytest.param( + "{issuer}/alt", + "/.well-known/oauth-authorization-server/alt", + id="oauth with path, IETF style", + ), + pytest.param( + "{issuer}/alt", + "/alt/.well-known/oauth-authorization-server", + id="oauth with path, broken OIDC style", + ), + pytest.param( + "{issuer}/alt", + "/alt/.well-known/openid-configuration", + id="openid with path, OIDC style", + ), + pytest.param( + "{issuer}/alt", + "/.well-known/openid-configuration/alt", + id="openid with path, IETF style", + ), + pytest.param( + "{issuer}/", + "//.well-known/openid-configuration", + id="empty path segment, OIDC style", + ), + pytest.param( + "{issuer}/", + "/.well-known/openid-configuration/", + id="empty path segment, IETF style", + ), + ], +) +def test_alternate_well_known_paths( + accept, openid_provider, issuer, path, server_discovery +): + issuer = issuer.format(issuer=openid_provider.issuer) + discovery_uri = openid_provider.issuer + path + + client_id = secrets.token_hex() + access_token = secrets.token_urlsafe() + + def discovery_handler(*args): + """ + Pass-through implementation of the discovery handler. Modifies the + default document to contain this test's issuer identifier. 
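+
+        For a concrete example: given an issuer of https://host/alt, the OIDC
+        style appends the well-known suffix to the path
+        (https://host/alt/.well-known/...), while the IETF style of RFC 8414
+        inserts it before the path (https://host/.well-known/.../alt).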
+ """ + code, doc = openid_provider._default_discovery_handler(*args) + doc["issuer"] = issuer + return code, doc + + openid_provider.register_endpoint(None, "GET", path, discovery_handler) + + def authorization_endpoint(headers, params): + resp = { + "device_code": "12345", + "user_code": "ABCDE", + "interval": 0, + "verification_url": "https://example.com/device", + "expires_in": 5, + } + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + def token_endpoint(headers, params): + # Successfully finish the request by sending the access bearer token. + resp = { + "access_token": access_token, + "token_type": "bearer", + } + + return 200, resp + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + kwargs = dict(oauth_client_id=client_id) + if server_discovery: + kwargs.update(oauth_issuer=issuer) + else: + kwargs.update(oauth_issuer=discovery_uri) + + sock, client = accept(**kwargs) + + with sock: + handle_discovery_connection(sock, discovery_uri) + + # Expect the client to connect again. + sock, _ = accept() + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + initial = start_oauth_handshake(conn) + + # Validate the token. + auth = get_auth_value(initial) + assert auth == f"Bearer {access_token}".encode("ascii") + + finish_handshake(conn) + + +@pytest.mark.parametrize( + "server_discovery", + [ + pytest.param(True, id="server discovery"), + pytest.param(False, id="direct discovery"), + ], +) +@pytest.mark.parametrize( + "issuer, path, expected_error", + [ + pytest.param( + "{issuer}", + "/.well-known/oauth-authorization-server/", + None, + id="extra empty segment (no path)", + ), + pytest.param( + "{issuer}/path", + "/.well-known/oauth-authorization-server/path/", + None, + id="extra empty segment (with path)", + ), + pytest.param( + "{issuer}", + "?/.well-known/oauth-authorization-server", + r'OAuth discovery URI ".*" must not contain query or fragment components', + id="query", + ), + pytest.param( + "{issuer}", + "#/.well-known/oauth-authorization-server", + r'OAuth discovery URI ".*" must not contain query or fragment components', + id="fragment", + ), + pytest.param( + "{issuer}/sub/path", + "/sub/.well-known/oauth-authorization-server/path", + r'OAuth discovery URI ".*" uses an invalid format', + id="sandwiched prefix", + ), + pytest.param( + "{issuer}/path", + "/path/openid-configuration", + r'OAuth discovery URI ".*" is not a .well-known URI', + id="not .well-known", + ), + pytest.param( + "{issuer}", + "https://.well-known/oauth-authorization-server", + r'OAuth discovery URI ".*" is not a .well-known URI', + id=".well-known prefix buried in the authority", + ), + pytest.param( + "{issuer}", + "/.well-known/oauth-protected-resource", + r'OAuth discovery URI ".*" uses an unsupported .well-known suffix', + id="unknown well-known suffix", + ), + pytest.param( + "{issuer}/path", + "/path/.well-known/openid-configuration-2", + r'OAuth discovery URI ".*" uses an unsupported .well-known suffix', + id="unknown well-known suffix, OIDC style", + ), + pytest.param( + "{issuer}/path", + "/.well-known/oauth-authorization-server-2/path", + r'OAuth discovery URI ".*" uses an unsupported .well-known suffix', + id="unknown well-known suffix, IETF style", + ), + pytest.param( + "{issuer}", + "file:///.well-known/oauth-authorization-server", + r'OAuth discovery URI ".*" must use HTTPS', + id="unsupported scheme", + ), + ], +) +def 
test_bad_well_known_paths( + accept, openid_provider, issuer, path, expected_error, server_discovery +): + if not server_discovery and "/.well-known/" not in path: + # An oauth_issuer without a /.well-known/ path segment is just a normal + # issuer identifier, so this isn't an interesting test. + pytest.skip("not interesting: direct discovery requires .well-known") + + issuer = issuer.format(issuer=openid_provider.issuer) + discovery_uri = urllib.parse.urljoin(openid_provider.issuer, path) + + client_id = secrets.token_hex() + + def discovery_handler(*args): + """ + Pass-through implementation of the discovery handler. Modifies the + default document to contain this test's issuer identifier. + """ + code, doc = openid_provider._default_discovery_handler(*args) + doc["issuer"] = issuer + return code, doc + + openid_provider.register_endpoint(None, "GET", path, discovery_handler) + + def fail(*args): + """ + No other endpoints should be contacted; fail if the client tries. + """ + assert False, "endpoint unexpectedly called" + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", fail + ) + openid_provider.register_endpoint("token_endpoint", "POST", "/token", fail) + + kwargs = dict(oauth_client_id=client_id) + if server_discovery: + kwargs.update(oauth_issuer=issuer) + else: + kwargs.update(oauth_issuer=discovery_uri) + + sock, client = accept(**kwargs) + with sock: + if expected_error and not server_discovery: + # If the client already knows the URL, it should disconnect as soon + # as it realizes it's not valid. + expect_disconnected_handshake(sock) + else: + # Otherwise, it should complete the connection. + handle_discovery_connection(sock, discovery_uri) + + # The client should not reconnect. + + if expected_error is None: + if server_discovery: + expected_error = rf"server's discovery document at {discovery_uri} \(issuer \".*\"\) is incompatible with oauth_issuer \({issuer}\)" + else: + expected_error = rf"the issuer identifier \({issuer}\) does not match oauth_issuer \(.*\)" + + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +def expect_disconnected_handshake(sock): + """ + Helper for any tests that expect the client to disconnect immediately after + being sent the OAUTHBEARER SASL method. Generally speaking, this requires + the client to have an oauth_issuer set so that it doesn't try to go through + discovery. + """ + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + # Initiate a handshake. + startup = pq3.recv1(conn, cls=pq3.Startup) + assert startup.proto == pq3.protocol(3, 0) + + pq3.send( + conn, + pq3.types.AuthnRequest, + type=pq3.authn.SASL, + body=[b"OAUTHBEARER", b""], + ) + + # The client should disconnect at this point. + assert not conn.read(1), "client sent unexpected data" + + +@pytest.mark.parametrize( + "missing", + [ + pytest.param(["oauth_issuer"], id="missing oauth_issuer"), + pytest.param(["oauth_client_id"], id="missing oauth_client_id"), + pytest.param(["oauth_client_id", "oauth_issuer"], id="missing both"), + ], +) +def test_oauth_requires_issuer_and_client_id(accept, openid_provider, missing): + params = dict( + oauth_issuer=openid_provider.issuer, + oauth_client_id="some-id", + ) + + # Remove required parameters. This should cause a client error after the + # server asks for OAUTHBEARER and the client tries to contact the issuer. 
+ for k in missing: + del params[k] + + sock, client = accept(**params) + with sock: + expect_disconnected_handshake(sock) + + expected_error = "oauth_issuer and oauth_client_id are not both set" + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +# See https://datatracker.ietf.org/doc/html/rfc6749#appendix-A for character +# class definitions. +all_vschars = "".join([chr(c) for c in range(0x20, 0x7F)]) +all_nqchars = "".join([chr(c) for c in range(0x21, 0x7F) if c not in (0x22, 0x5C)]) + + +@pytest.mark.parametrize("client_id", ["", ":", " + ", r'+=&"\/~', all_vschars]) +@pytest.mark.parametrize("secret", [None, "", ":", " + ", r'+=&"\/~', all_vschars]) +@pytest.mark.parametrize("device_code", ["", " + ", r'+=&"\/~', all_vschars]) +@pytest.mark.parametrize("scope", ["&", r"+=&/", all_nqchars]) +def test_url_encoding(accept, openid_provider, client_id, secret, device_code, scope): + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id=client_id, + oauth_client_secret=secret, + oauth_scope=scope, + ) + + user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}" + verification_url = "https://example.com/device" + + access_token = secrets.token_urlsafe() + + def check_client_authn(headers, params): + if secret is None: + assert "Authorization" not in headers + assert params["client_id"] == [client_id] + return + + # Require the client to use Basic authn; request-body credentials are + # NOT RECOMMENDED (RFC 6749, Sec. 2.3.1). + assert "Authorization" in headers + assert "client_id" not in params + + method, creds = headers["Authorization"].split() + assert method == "Basic" + + decoded = base64.b64decode(creds).decode("utf-8") + username, password = decoded.split(":", 1) + + expected_username = urllib.parse.quote_plus(client_id) + expected_password = urllib.parse.quote_plus(secret) + + assert [username, password] == [expected_username, expected_password] + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + check_client_authn(headers, params) + + if scope: + assert params["scope"] == [scope] + else: + assert "scope" not in params + + resp = { + "device_code": device_code, + "user_code": user_code, + "interval": 0, + "verification_url": verification_url, + "expires_in": 5, + } + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + def token_endpoint(headers, params): + check_client_authn(headers, params) + + assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"] + assert params["device_code"] == [device_code] + + # Successfully finish the request by sending the access bearer token. + resp = { + "access_token": access_token, + "token_type": "bearer", + } + + return 200, resp + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + # First connection is a discovery request, which should result in the above + # endpoints being called. + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + # Second connection sends the token. + sock, _ = accept() + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + initial = start_oauth_handshake(conn) + + # Validate and accept the token. 
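+            # (The token must survive the round trip byte-for-byte despite the
+            # hostile characters in this test's parameters.)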
+ auth = get_auth_value(initial) + assert auth == f"Bearer {access_token}".encode("ascii") + + finish_handshake(conn) + + +@pytest.mark.slow +@pytest.mark.parametrize("error_code", ["authorization_pending", "slow_down"]) +@pytest.mark.parametrize("retries", [1, 2]) +@pytest.mark.parametrize("omit_interval", [True, False]) +def test_oauth_retry_interval( + accept, openid_provider, omit_interval, retries, error_code +): + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id="some-id", + ) + + expected_retry_interval = 5 if omit_interval else 1 + access_token = secrets.token_urlsafe() + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + resp = { + "device_code": "my-device-code", + "user_code": "my-user-code", + "verification_uri": "https://example.com", + "expires_in": 5, + } + + if not omit_interval: + resp["interval"] = expected_retry_interval + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + attempts = 0 + last_retry = None + retry_lock = threading.Lock() + token_sent = threading.Event() + + def token_endpoint(headers, params): + now = time.monotonic() + + with retry_lock: + nonlocal attempts, last_retry, expected_retry_interval + + # Make sure the retry interval is being respected by the client. + if last_retry is not None: + interval = now - last_retry + assert interval >= expected_retry_interval + + last_retry = now + + # If the test wants to force the client to retry, return the desired + # error response and decrement the retry count. + if attempts < retries: + attempts += 1 + + # A slow_down code requires the client to additionally increase + # its interval by five seconds. + if error_code == "slow_down": + expected_retry_interval += 5 + + return 400, {"error": error_code} + + # Successfully finish the request by sending the access bearer token, + # and signal the main thread to continue. + resp = { + "access_token": access_token, + "token_type": "bearer", + } + token_sent.set() + + return 200, resp + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + # First connection is a discovery request, which should result in the above + # endpoints being called. + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + # At this point the client is talking to the authorization server. Wait for + # that to succeed so we don't run into the accept() timeout. + token_sent.wait() + + # Client should reconnect and send the token. + sock, _ = accept() + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + initial = start_oauth_handshake(conn) + + # Validate and accept the token. + auth = get_auth_value(initial) + assert auth == f"Bearer {access_token}".encode("ascii") + + finish_handshake(conn) + + +@pytest.fixture +def self_pipe(): + """ + Yields a pipe fd pair. 
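+
+    Tests hand the read end to libpq as an alternate socket to wait on, and
+    wake the client by writing a byte to the other end.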
+ """ + + class _Pipe: + pass + + p = _Pipe() + p.readfd, p.writefd = os.pipe() + + try: + yield p + finally: + os.close(p.readfd) + os.close(p.writefd) + + +@pytest.mark.parametrize("scope", [None, "", "openid email"]) +@pytest.mark.parametrize( + "retries", + [ + -1, # no async callback + 0, # async callback immediately returns token + 1, # async callback waits on altsock once + 2, # async callback waits on altsock twice + ], +) +@pytest.mark.parametrize( + "asynchronous", + [ + pytest.param(False, id="synchronous"), + pytest.param(True, id="asynchronous"), + ], +) +def test_user_defined_flow( + accept, auth_data_cb, self_pipe, scope, retries, asynchronous +): + issuer = "http://localhost" + discovery_uri = issuer + "/.well-known/openid-configuration" + access_token = secrets.token_urlsafe() + + sock, client = accept( + oauth_issuer=discovery_uri, + oauth_client_id="some-id", + oauth_scope=scope, + async_=asynchronous, + ) + + # Track callbacks. + attempts = 0 + wakeup_called = False + cleanup_calls = 0 + lock = threading.Lock() + + def wakeup(): + """Writes a byte to the wakeup pipe.""" + nonlocal wakeup_called + with lock: + wakeup_called = True + os.write(self_pipe.writefd, b"\0") + + def get_token(pgconn, request, p_altsock): + """ + Async token callback. While attempts < retries, libpq will be instructed + to wait on the self_pipe. When attempts == retries, the token will be + set. + + Note that assertions and exceptions raised here are allowed but not very + helpful, since they can't bubble through the libpq stack to be collected + by the test suite. Try not to rely too heavily on them. + """ + # Make sure libpq passed our user data through. + assert request.user == 42 + + with lock: + nonlocal attempts, wakeup_called + + if attempts: + # If we've already started the timer, we shouldn't get a + # call back before it trips. + assert wakeup_called, "authdata hook was called before the timer" + + # Drain the wakeup byte. + os.read(self_pipe.readfd, 1) + + if attempts < retries: + attempts += 1 + + # Wake up the client in a little bit of time. + wakeup_called = False + threading.Timer(0.1, wakeup).start() + + # Tell libpq to wait on the other end of the wakeup pipe. + p_altsock[0] = self_pipe.readfd + return PGRES_POLLING_READING + + # Done! + request.token = access_token.encode() + return PGRES_POLLING_OK + + @ctypes.CFUNCTYPE( + ctypes.c_int, + ctypes.c_void_p, + ctypes.POINTER(PGOAuthBearerRequest), + ctypes.POINTER(ctypes.c_int), + ) + def get_token_wrapper(pgconn, p_request, p_altsock): + """ + Translation layer between C and Python for the async callback. + Assertions and exceptions will be swallowed at the boundary, so make + sure they don't escape here. + """ + try: + return get_token(pgconn, p_request.contents, p_altsock) + except Exception: + logging.error("Exception during async callback:\n" + traceback.format_exc()) + return PGRES_POLLING_FAILED + + @ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.POINTER(PGOAuthBearerRequest)) + def cleanup(pgconn, p_request): + """ + Should be called exactly once per connection. + """ + nonlocal cleanup_calls + with lock: + cleanup_calls += 1 + + def bearer_hook(typ, pgconn, request): + """ + Implementation of the PQAuthDataHook, which either sets up an async + callback or returns the token directly, depending on the value of + retries. + + As above, try not to rely too much on assertions/exceptions here. 
+ """ + assert typ == PQAUTHDATA_OAUTH_BEARER_TOKEN + request.cleanup = cleanup + + if retries < 0: + # Special case: return a token immediately without a callback. + request.token = access_token.encode() + return 1 + + # Tell libpq to call us back. + request.async_ = get_token_wrapper + request.user = ctypes.c_void_p(42) # will be checked in the callback + return 1 + + auth_data_cb.impl = bearer_hook + + # Now drive the server side. + if retries >= 0: + # First connection is a discovery request, which should result in the + # hook being invoked. + with sock: + handle_discovery_connection(sock, discovery_uri) + + # Client should reconnect to send the token. + sock, _ = accept() + + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + # Initiate a handshake, which should result in our custom callback + # being invoked to fetch the token. + initial = start_oauth_handshake(conn) + + # Validate and accept the token. + auth = get_auth_value(initial) + assert auth == f"Bearer {access_token}".encode("ascii") + + finish_handshake(conn) + + # Check the data provided to the hook. + assert len(auth_data_cb.calls) == 1 + + call = auth_data_cb.calls[0] + assert call.type == PQAUTHDATA_OAUTH_BEARER_TOKEN + assert call.openid_configuration.decode() == discovery_uri + assert call.scope == (None if scope is None else scope.encode()) + + # Make sure we clean up after ourselves when the connection is finished. + client.check_completed() + assert cleanup_calls == 1 + + +def alt_patterns(*patterns): + """ + Just combines multiple alternative regexes into one. It's not very efficient + but IMO it's easier to read and maintain. + """ + pat = "" + + for p in patterns: + if pat: + pat += "|" + pat += f"({p})" + + return pat + + +@pytest.mark.parametrize( + "failure_mode, error_pattern", + [ + pytest.param( + ( + 401, + { + "error": "invalid_client", + "error_description": "client authentication failed", + }, + ), + r"failed to obtain device authorization: client authentication failed \(invalid_client\)", + id="authentication failure with description", + ), + pytest.param( + (400, {"error": "invalid_request"}), + r"failed to obtain device authorization: \(invalid_request\)", + id="invalid request without description", + ), + pytest.param( + (400, {"error": "invalid_request", "padding": "x" * 256 * 1024}), + r"failed to obtain device authorization: response is too large", + id="gigantic authz response", + ), + pytest.param( + (200, RawResponse('{"":' + "[" * 16)), + r"failed to parse device authorization: JSON is too deeply nested", + id="overly nested authz response array", + ), + pytest.param( + (200, RawResponse('{"":' * 17)), + r"failed to parse device authorization: JSON is too deeply nested", + id="overly nested authz response object", + ), + pytest.param( + (400, {}), + r'failed to parse token error response: field "error" is missing', + id="broken error response", + ), + pytest.param( + (401, {"error": "invalid_client"}), + r"failed to obtain device authorization: provider requires client authentication, and no oauth_client_secret is set \(invalid_client\)", + id="failed authentication without description", + ), + pytest.param( + (200, RawResponse(r'{ "interval": 3.5.8 }')), + r"failed to parse device authorization: Token .* is invalid", + id="non-numeric interval", + ), + pytest.param( + (200, RawResponse(r'{ "interval": 08 }')), + r"failed to parse device authorization: Token .* is invalid", + id="invalid numeric interval", + ), + ], +) +def test_oauth_device_authorization_failures( + 
accept, openid_provider, failure_mode, error_pattern +): + client_id = secrets.token_hex() + + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id=client_id, + ) + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + return failure_mode + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + def token_endpoint(headers, params): + assert False, "token endpoint was invoked unexpectedly" + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + # Now make sure the client correctly failed. + with pytest.raises(psycopg2.OperationalError, match=error_pattern): + client.check_completed() + + +Missing = object() # sentinel for test_oauth_device_authorization_bad_json() + + +@pytest.mark.parametrize( + "bad_value", + [ + pytest.param({"device_code": 3}, id="object"), + pytest.param([1, 2, 3], id="array"), + pytest.param("some string", id="string"), + pytest.param(4, id="numeric"), + pytest.param(False, id="boolean"), + pytest.param(None, id="null"), + pytest.param(Missing, id="missing"), + ], +) +@pytest.mark.parametrize( + "field_name,ok_type,required", + [ + ("device_code", str, True), + ("user_code", str, True), + ("verification_uri", str, True), + ("interval", int, False), + ], +) +def test_oauth_device_authorization_bad_json_schema( + accept, openid_provider, field_name, ok_type, required, bad_value +): + # To make the test matrix easy, just skip the tests that aren't actually + # interesting (field of the correct type, missing optional field). + if bad_value is Missing and not required: + pytest.skip("not interesting: optional field") + elif type(bad_value) == ok_type: # not isinstance(), because bool is an int + pytest.skip("not interesting: correct type") + + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id=secrets.token_hex(), + ) + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + # Begin with an acceptable base response... + resp = { + "device_code": "my-device-code", + "user_code": "my-user-code", + "interval": 0, + "verification_uri": "https://example.com", + "expires_in": 5, + } + + # ...then tweak it so the client fails. + if bad_value is Missing: + del resp[field_name] + else: + resp[field_name] = bad_value + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + def token_endpoint(headers, params): + assert False, "token endpoint was invoked unexpectedly" + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + # Now make sure the client correctly failed. 
+ if bad_value is Missing: + error_pattern = f'field "{field_name}" is missing' + elif ok_type == str: + error_pattern = f'field "{field_name}" must be a string' + elif ok_type == int: + error_pattern = f'field "{field_name}" must be a number' + else: + assert False, "update error_pattern for new failure mode" + + with pytest.raises(psycopg2.OperationalError, match=error_pattern): + client.check_completed() + + +@pytest.mark.parametrize( + "failure_mode, error_pattern", + [ + pytest.param( + ( + 400, + { + "error": "expired_token", + "error_description": "the device code has expired", + }, + ), + r"failed to obtain access token: the device code has expired \(expired_token\)", + id="expired token with description", + ), + pytest.param( + (400, {"error": "access_denied"}), + r"failed to obtain access token: \(access_denied\)", + id="access denied without description", + ), + pytest.param( + (400, {"error": "access_denied", "padding": "x" * 256 * 1024}), + r"failed to obtain access token: response is too large", + id="gigantic token response", + ), + pytest.param( + (200, RawResponse('{"":' + "[" * 16)), + r"failed to parse access token response: JSON is too deeply nested", + id="overly nested token response array", + ), + pytest.param( + (200, RawResponse('{"":' * 17)), + r"failed to parse access token response: JSON is too deeply nested", + id="overly nested token response object", + ), + pytest.param( + (400, {}), + r'failed to parse token error response: field "error" is missing', + id="empty error response", + ), + pytest.param( + (401, {"error": "invalid_client"}), + r"failed to obtain access token: provider requires client authentication, and no oauth_client_secret is set \(invalid_client\)", + id="authentication failure without description", + ), + pytest.param( + (200, {}, {}), + r"failed to parse access token response: no content type was provided", + id="missing content type", + ), + pytest.param( + (200, {"Content-Type": "text/plain"}, {}), + r"failed to parse access token response: unexpected content type", + id="wrong content type", + ), + pytest.param( + (200, {"Content-Type": "application/jsonx"}, {}), + r"failed to parse access token response: unexpected content type", + id="wrong content type (correct prefix)", + ), + ], +) +@pytest.mark.parametrize("retries", [0, 1]) +def test_oauth_token_failures( + accept, openid_provider, retries, failure_mode, error_pattern +): + client_id = secrets.token_hex() + + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id=client_id, + ) + + device_code = secrets.token_hex() + user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}" + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + assert params["client_id"] == [client_id] + + resp = { + "device_code": device_code, + "user_code": user_code, + "interval": 0, + "verification_uri": "https://example.com/device", + "expires_in": 5, + } + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + retry_lock = threading.Lock() + final_sent = False + + def token_endpoint(headers, params): + with retry_lock: + nonlocal retries, final_sent + + # If the test wants to force the client to retry, return an + # authorization_pending response and decrement the retry count. 
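+            # (authorization_pending is the standard "keep polling" error for
+            # the device grant; see RFC 8628, Sec. 3.5.)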
+ if retries > 0: + retries -= 1 + return 400, {"error": "authorization_pending"} + + # We should only return our failure_mode response once; any further + # requests indicate that the client isn't correctly bailing out. + assert not final_sent, "client continued after token error" + + final_sent = True + + return failure_mode + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + # Now make sure the client correctly failed. + with pytest.raises(psycopg2.OperationalError, match=error_pattern): + client.check_completed() + + +@pytest.mark.parametrize( + "bad_value", + [ + pytest.param({"device_code": 3}, id="object"), + pytest.param([1, 2, 3], id="array"), + pytest.param("some string", id="string"), + pytest.param(4, id="numeric"), + pytest.param(False, id="boolean"), + pytest.param(None, id="null"), + pytest.param(Missing, id="missing"), + ], +) +@pytest.mark.parametrize( + "field_name,ok_type,required", + [ + ("access_token", str, True), + ("token_type", str, True), + ], +) +def test_oauth_token_bad_json_schema( + accept, openid_provider, field_name, ok_type, required, bad_value +): + # To make the test matrix easy, just skip the tests that aren't actually + # interesting (field of the correct type, missing optional field). + if bad_value is Missing and not required: + pytest.skip("not interesting: optional field") + elif type(bad_value) == ok_type: # not isinstance(), because bool is an int + pytest.skip("not interesting: correct type") + + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id=secrets.token_hex(), + ) + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + resp = { + "device_code": "my-device-code", + "user_code": "my-user-code", + "interval": 0, + "verification_uri": "https://example.com", + "expires_in": 5, + } + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + def token_endpoint(headers, params): + # Begin with an acceptable base response... + resp = { + "access_token": secrets.token_urlsafe(), + "token_type": "bearer", + } + + # ...then tweak it so the client fails. + if bad_value is Missing: + del resp[field_name] + else: + resp[field_name] = bad_value + + return 200, resp + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + # Now make sure the client correctly failed. 
+ error_pattern = "failed to parse access token response: " + if bad_value is Missing: + error_pattern += f'field "{field_name}" is missing' + elif ok_type == str: + error_pattern += f'field "{field_name}" must be a string' + elif ok_type == int: + error_pattern += f'field "{field_name}" must be a number' + else: + assert False, "update error_pattern for new failure mode" + + with pytest.raises(psycopg2.OperationalError, match=error_pattern): + client.check_completed() + + +@pytest.mark.parametrize("success", [True, False]) +@pytest.mark.parametrize("scope", [None, "openid email"]) +@pytest.mark.parametrize( + "base_response", + [ + {"status": "invalid_token"}, + {"extra_object": {"key": "value"}, "status": "invalid_token"}, + {"extra_object": {"status": 1}, "status": "invalid_token"}, + ], +) +def test_oauth_discovery(accept, openid_provider, base_response, scope, success): + sock, client = accept( + oauth_issuer=openid_provider.issuer, + oauth_client_id=secrets.token_hex(), + ) + + device_code = secrets.token_hex() + user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}" + verification_url = "https://example.com/device" + + access_token = secrets.token_urlsafe() + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + if scope: + assert params["scope"] == [scope] + else: + assert "scope" not in params + + resp = { + "device_code": device_code, + "user_code": user_code, + "interval": 0, + "verification_uri": verification_url, + "expires_in": 5, + } + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + def token_endpoint(headers, params): + assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"] + assert params["device_code"] == [device_code] + + # Successfully finish the request by sending the access bearer token. + resp = { + "access_token": access_token, + "token_type": "bearer", + } + + return 200, resp + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + # Construct the response to use when failing the SASL exchange. Return a + # link to the discovery document, pointing to the test provider server. + fail_resp = { + **base_response, + "openid-configuration": openid_provider.discovery_uri, + } + + if scope: + fail_resp["scope"] = scope + + with sock: + handle_discovery_connection(sock, response=fail_resp) + + # The client will connect to us a second time, using the parameters we sent + # it. + sock, _ = accept() + + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + initial = start_oauth_handshake(conn) + + # Validate the token. + auth = get_auth_value(initial) + assert auth == f"Bearer {access_token}".encode("ascii") + + if success: + finish_handshake(conn) + + else: + # Simulate token validation failure. + expected_error = "test token validation failure" + fail_oauth_handshake(conn, fail_resp, errmsg=expected_error) + + if not success: + # The client should not try to connect again. 
+ with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +@pytest.mark.parametrize( + "response,expected_error", + [ + pytest.param( + "abcde", + 'Token "abcde" is invalid', + id="bad JSON: invalid syntax", + ), + pytest.param( + b"\xff\xff\xff\xff", + "server's error response is not valid UTF-8", + id="bad JSON: invalid encoding", + ), + pytest.param( + '"abcde"', + "top-level element must be an object", + id="bad JSON: top-level element is a string", + ), + pytest.param( + "[]", + "top-level element must be an object", + id="bad JSON: top-level element is an array", + ), + pytest.param( + "{}", + "server sent error response without a status", + id="bad JSON: no status member", + ), + pytest.param( + '{ "status": null }', + 'field "status" must be a string', + id="bad JSON: null status member", + ), + pytest.param( + '{ "status": 0 }', + 'field "status" must be a string', + id="bad JSON: int status member", + ), + pytest.param( + '{ "status": [ "bad" ] }', + 'field "status" must be a string', + id="bad JSON: array status member", + ), + pytest.param( + '{ "status": { "bad": "bad" } }', + 'field "status" must be a string', + id="bad JSON: object status member", + ), + pytest.param( + '{ "nested": { "status": "bad" } }', + "server sent error response without a status", + id="bad JSON: nested status", + ), + pytest.param( + '{ "status": "invalid_token" ', + "The input string ended unexpectedly", + id="bad JSON: unterminated object", + ), + pytest.param( + '{ "status": "invalid_token" } { }', + 'Expected end of input, but found "{"', + id="bad JSON: trailing data", + ), + pytest.param( + '{ "status": "invalid_token", "openid-configuration": 1 }', + 'field "openid-configuration" must be a string', + id="bad JSON: int openid-configuration member", + ), + pytest.param( + '{ "status": "invalid_token", "openid-configuration": 1 }', + 'field "openid-configuration" must be a string', + id="bad JSON: int openid-configuration member", + ), + pytest.param( + '{ "status": "invalid_token", "openid-configuration": "", "openid-configuration": "" }', + 'field "openid-configuration" is duplicated', + id="bad JSON: duplicated field", + ), + pytest.param( + '{ "status": "invalid_token", "scope": 1 }', + 'field "scope" must be a string', + id="bad JSON: int scope member", + ), + ], +) +def test_oauth_discovery_server_error(accept, response, expected_error): + sock, client = accept( + oauth_issuer="https://example.com", + oauth_client_id=secrets.token_hex(), + ) + + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + initial = start_oauth_handshake(conn) + + if isinstance(response, str): + response = response.encode("utf-8") + + # Fail the SASL exchange with an invalid JSON response. + pq3.send( + conn, + pq3.types.AuthnRequest, + type=pq3.authn.SASLContinue, + body=response, + ) + + # The client should disconnect, so the socket is closed here. (If + # the client doesn't disconnect, it will report a different error + # below and the test will fail.) + + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +# All of these tests are expected to fail before libpq tries to actually attempt +# a connection to any endpoint. To avoid hitting the network in the event that a +# test fails, an invalid IPv4 address (256.256.256.256) is used as a hostname. 
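+# (IPv4 octets max out at 255, so 256.256.256.256 can never resolve, and any
+# accidental network attempt fails immediately.)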
+@pytest.mark.parametrize( + "bad_response,expected_error", + [ + pytest.param( + (200, {"Content-Type": "text/plain"}, {}), + r'failed to parse OpenID discovery document: unexpected content type: "text/plain"', + id="not JSON", + ), + pytest.param( + (200, {}, {}), + r"failed to parse OpenID discovery document: no content type was provided", + id="no Content-Type", + ), + pytest.param( + (204, {}, None), + r"failed to fetch OpenID discovery document: unexpected response code 204", + id="no content", + ), + pytest.param( + (301, {"Location": "https://localhost/"}, None), + r"failed to fetch OpenID discovery document: unexpected response code 301", + id="redirection", + ), + pytest.param( + (404, {}), + r"failed to fetch OpenID discovery document: unexpected response code 404", + id="not found", + ), + pytest.param( + (200, RawResponse("blah\x00blah")), + r"failed to parse OpenID discovery document: response contains embedded NULLs", + id="NULL bytes in document", + ), + pytest.param( + (200, RawBytes(b"blah\xffblah")), + r"failed to parse OpenID discovery document: response is not valid UTF-8", + id="document is not UTF-8", + ), + pytest.param( + (200, 123), + r"failed to parse OpenID discovery document: top-level element must be an object", + id="scalar at top level", + ), + pytest.param( + (200, []), + r"failed to parse OpenID discovery document: top-level element must be an object", + id="array at top level", + ), + pytest.param( + (200, RawResponse("{")), + r"failed to parse OpenID discovery document.* input string ended unexpectedly", + id="unclosed object", + ), + pytest.param( + (200, RawResponse(r'{ "hello": ] }')), + r"failed to parse OpenID discovery document.* Expected JSON value", + id="bad array", + ), + pytest.param( + (200, {"issuer": 123}), + r'failed to parse OpenID discovery document: field "issuer" must be a string', + id="non-string issuer", + ), + pytest.param( + (200, {"issuer": ["something"]}), + r'failed to parse OpenID discovery document: field "issuer" must be a string', + id="issuer array", + ), + pytest.param( + (200, {"issuer": {}}), + r'failed to parse OpenID discovery document: field "issuer" must be a string', + id="issuer object", + ), + pytest.param( + (200, {"grant_types_supported": 123}), + r'failed to parse OpenID discovery document: field "grant_types_supported" must be an array of strings', + id="numeric grant types field", + ), + pytest.param( + ( + 200, + { + "grant_types_supported": "urn:ietf:params:oauth:grant-type:device_code" + }, + ), + r'failed to parse OpenID discovery document: field "grant_types_supported" must be an array of strings', + id="string grant types field", + ), + pytest.param( + (200, {"grant_types_supported": {}}), + r'failed to parse OpenID discovery document: field "grant_types_supported" must be an array of strings', + id="object grant types field", + ), + pytest.param( + (200, {"grant_types_supported": [123]}), + r'failed to parse OpenID discovery document: field "grant_types_supported" must be an array of strings', + id="non-string grant types", + ), + pytest.param( + (200, {"grant_types_supported": ["something", 123]}), + r'failed to parse OpenID discovery document: field "grant_types_supported" must be an array of strings', + id="non-string grant types later in the list", + ), + pytest.param( + (200, {"grant_types_supported": ["something", {}]}), + r'failed to parse OpenID discovery document: field "grant_types_supported" must be an array of strings', + id="object grant types later in the list", + ), + pytest.param( + 
(200, {"grant_types_supported": ["something", ["something"]]}), + r'failed to parse OpenID discovery document: field "grant_types_supported" must be an array of strings', + id="embedded array grant types later in the list", + ), + pytest.param( + ( + 200, + { + "grant_types_supported": ["something"], + "token_endpoint": "https://256.256.256.256/", + "issuer": 123, + }, + ), + r'failed to parse OpenID discovery document: field "issuer" must be a string', + id="non-string issuer after other valid fields", + ), + pytest.param( + ( + 200, + { + "ignored": {"grant_types_supported": 123, "token_endpoint": 123}, + "issuer": 123, + }, + ), + r'failed to parse OpenID discovery document: field "issuer" must be a string', + id="non-string issuer after other ignored fields", + ), + pytest.param( + (200, {"token_endpoint": "https://256.256.256.256/"}), + r'failed to parse OpenID discovery document: field "issuer" is missing', + id="missing issuer", + ), + pytest.param( + (200, {"issuer": "{issuer}"}), + r'failed to parse OpenID discovery document: field "token_endpoint" is missing', + id="missing token endpoint", + ), + pytest.param( + ( + 200, + { + "issuer": "{issuer}", + "token_endpoint": "https://256.256.256.256/token", + "grant_types_supported": [ + "urn:ietf:params:oauth:grant-type:device_code" + ], + }, + ), + r'cannot run OAuth device authorization: issuer "https://.*" does not provide a device authorization endpoint', + id="missing device_authorization_endpoint", + ), + pytest.param( + ( + 200, + { + "issuer": "{issuer}", + "token_endpoint": "https://256.256.256.256/token", + "grant_types_supported": [ + "urn:ietf:params:oauth:grant-type:device_code" + ], + "device_authorization_endpoint": "https://256.256.256.256/dev", + "filler": "x" * 256 * 1024, + }, + ), + r"failed to fetch OpenID discovery document: response is too large", + id="gigantic discovery response", + ), + pytest.param( + ( + 200, + RawResponse('{"":' + "[" * 16), + ), + r"failed to parse OpenID discovery document: JSON is too deeply nested", + id="overly nested discovery response array", + ), + pytest.param( + ( + 200, + RawResponse('{"":' * 17), + ), + r"failed to parse OpenID discovery document: JSON is too deeply nested", + id="overly nested discovery response object", + ), + pytest.param( + ( + 200, + { + "issuer": "{issuer}/path", + "token_endpoint": "https://256.256.256.256/token", + "grant_types_supported": [ + "urn:ietf:params:oauth:grant-type:device_code" + ], + "device_authorization_endpoint": "https://256.256.256.256/dev", + }, + ), + r"failed to parse OpenID discovery document: the issuer identifier \(https://.*/path\) does not match oauth_issuer \(https://.*\)", + id="mismatched issuer identifier", + ), + pytest.param( + ( + 200, + RawResponse( + """{ + "issuer": "https://256.256.256.256/path", + "token_endpoint": "https://256.256.256.256/token", + "grant_types_supported": [ + "urn:ietf:params:oauth:grant-type:device_code" + ], + "device_authorization_endpoint": "https://256.256.256.256/dev", + "device_authorization_endpoint": "https://256.256.256.256/dev" + }""" + ), + ), + r'failed to parse OpenID discovery document: field "device_authorization_endpoint" is duplicated', + id="duplicated field", + ), + # + # Exercise HTTP-level failures by breaking the protocol. Note that the + # error messages here are implementation-dependent. 
+ # + pytest.param( + (1000, {}), + r"failed to fetch OpenID discovery document: Unsupported protocol \(.*\)", + id="invalid HTTP response code", + ), + pytest.param( + (200, {"Content-Length": -1}, {}), + r"failed to fetch OpenID discovery document: Weird server reply \(.*Content-Length.*\)", + id="bad HTTP Content-Length", + ), + ], +) +def test_oauth_discovery_provider_failure( + accept, openid_provider, bad_response, expected_error +): + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id=secrets.token_hex(), + ) + + def failing_discovery_handler(headers, params): + try: + # Insert the correct issuer value if the test wants to. + resp = bad_response[1] + iss = resp["issuer"] + resp["issuer"] = iss.format(issuer=openid_provider.issuer) + except (AttributeError, KeyError, TypeError): + pass + + return bad_response + + openid_provider.register_endpoint( + None, + "GET", + "/.well-known/openid-configuration", + failing_discovery_handler, + ) + + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +@pytest.mark.parametrize( + "sasl_err,resp_type,resp_payload,expected_error", + [ + pytest.param( + {"status": "invalid_request"}, + pq3.types.ErrorResponse, + dict( + fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""], + ), + "server rejected OAuth bearer token: invalid_request", + id="standard server error: invalid_request", + ), + pytest.param( + {"": [[[[[[[]]]]]]], "status": "invalid_request"}, + pq3.types.ErrorResponse, + dict( + fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""], + ), + "server rejected OAuth bearer token: invalid_request", + id="standard server error: invalid_request with ignored array", + ), + pytest.param( + {"status": "invalid_token"}, + pq3.types.ErrorResponse, + dict( + fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""], + ), + "expected error message", + id="standard server error: invalid_token without discovery URI", + ), + pytest.param( + {"status": "invalid_token", "openid-configuration": ""}, + pq3.types.AuthnRequest, + dict(type=pq3.authn.SASLContinue, body=b""), + "server sent additional OAuth data", + id="broken server: additional challenge after error", + ), + pytest.param( + {"status": "invalid_token", "openid-configuration": ""}, + pq3.types.AuthnRequest, + dict(type=pq3.authn.SASLFinal), + "server sent additional OAuth data", + id="broken server: SASL success after error", + ), + pytest.param( + {"status": "invalid_token", "openid-configuration": ""}, + pq3.types.AuthnRequest, + dict(type=pq3.authn.SASL, body=[b"OAUTHBEARER", b""]), + "duplicate SASL authentication request", + id="broken server: SASL reinitialization after error", + ), + pytest.param( + RawResponse('{"":' + "[" * 8), + pq3.types.AuthnRequest, + dict(type=pq3.authn.SASL, body=[b"OAUTHBEARER", b""]), + "JSON is too deeply nested", + id="broken server: overly nested JSON response array", + ), + pytest.param( + RawResponse('{"":' * 9), + pq3.types.AuthnRequest, + dict(type=pq3.authn.SASL, body=[b"OAUTHBEARER", b""]), + "JSON is too deeply nested", + id="broken server: overly nested JSON response object", + ), + ], +) +def test_oauth_server_error( + accept, auth_data_cb, sasl_err, resp_type, resp_payload, expected_error +): + wkuri = f"https://256.256.256.256/.well-known/openid-configuration" + sock, client = accept( + oauth_issuer=wkuri, + oauth_client_id="some-id", + ) + + def bearer_hook(typ, 
pgconn, request): + """ + Implementation of the PQAuthDataHook, which returns a token directly so + we don't need an openid_provider instance. + """ + assert typ == PQAUTHDATA_OAUTH_BEARER_TOKEN + request.token = secrets.token_urlsafe().encode() + return 1 + + auth_data_cb.impl = bearer_hook + + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + start_oauth_handshake(conn) + + # Ignore the client data. Return an error "challenge". + if isinstance(sasl_err, RawResponse): + resp = sasl_err + else: + if "openid-configuration" in sasl_err: + sasl_err["openid-configuration"] = wkuri + + resp = json.dumps(sasl_err) + + resp = resp.encode("utf-8") + pq3.send( + conn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=resp + ) + + # Per RFC, the client is required to send a dummy ^A response. + pkt = pq3.recv1(conn) + assert pkt.type == pq3.types.PasswordMessage + assert pkt.payload == b"\x01" + + # Now fail the SASL exchange (in either a valid way, or an + # invalid one, depending on the test). + pq3.send(conn, resp_type, **resp_payload) + + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +def test_oauth_interval_overflow(accept, openid_provider): + """ + A really badly behaved server could send a huge interval and then + immediately tell us to slow_down; ensure we handle this without breaking. + """ + # (should be equivalent to the INT_MAX in limits.h) + int_max = ctypes.c_uint(-1).value // 2 + + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id=secrets.token_hex(), + ) + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + resp = { + "device_code": "my-device-code", + "user_code": "my-user-code", + "verification_uri": "https://example.com", + "expires_in": 5, + "interval": int_max, + } + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + def token_endpoint(headers, params): + return 400, {"error": "slow_down"} + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + expected_error = "slow_down interval overflow" + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +def test_oauth_refuses_http(accept, openid_provider, monkeypatch): + """ + HTTP must be refused without PGOAUTHDEBUG. + """ + monkeypatch.delenv("PGOAUTHDEBUG") + + def to_http(uri): + """Swaps out a URI's scheme for http.""" + parts = urllib.parse.urlparse(uri) + parts = parts._replace(scheme="http") + return urllib.parse.urlunparse(parts) + + sock, client = accept( + oauth_issuer=to_http(openid_provider.issuer), + oauth_client_id=secrets.token_hex(), + ) + + # No provider callbacks necessary; we should fail immediately. 
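+    # libpq should refuse the http:// discovery URI up front (per the HTTPS
+    # requirement asserted below), so the provider never has to answer.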
+ + with sock: + handle_discovery_connection(sock, to_http(openid_provider.discovery_uri)) + + expected_error = r'OAuth discovery URI ".*" must use HTTPS' + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +@pytest.mark.parametrize("auth_type", [pq3.authn.OK, pq3.authn.SASLFinal]) +def test_discovery_incorrectly_permits_connection(accept, auth_type): + """ + Incorrectly responds to a client's discovery request with AuthenticationOK + or AuthenticationSASLFinal. require_auth=oauth should catch the former, and + the mechanism itself should catch the latter. + """ + issuer = "https://256.256.256.256" + sock, client = accept( + oauth_issuer=issuer, + oauth_client_id=secrets.token_hex(), + require_auth="oauth", + ) + + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + initial = start_oauth_handshake(conn) + + auth = get_auth_value(initial) + assert auth == b"" + + # Incorrectly log the client in. It should immediately disconnect. + pq3.send(conn, pq3.types.AuthnRequest, type=auth_type) + assert not conn.read(1), "client sent unexpected data" + + if auth_type == pq3.authn.OK: + expected_error = "server did not complete authentication" + else: + expected_error = "server sent unexpected additional OAuth data" + + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +def test_no_discovery_url_provided(accept): + """ + Tests what happens when the client doesn't know who to contact and the + server doesn't tell it. + """ + issuer = "https://256.256.256.256" + sock, client = accept( + oauth_issuer=issuer, + oauth_client_id=secrets.token_hex(), + ) + + with sock: + handle_discovery_connection(sock, discovery=None) + + expected_error = "no discovery metadata was provided" + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + + +@pytest.mark.parametrize("change_between_connections", [False, True]) +def test_discovery_url_changes(accept, openid_provider, change_between_connections): + """ + Ensures that the client complains if the server agrees on the issuer, but + disagrees on the discovery URL to be used. + """ + + # Set up our provider callbacks. + # NOTE that these callbacks will be called on a background thread. Don't do + # any unprotected state mutation here. + + def authorization_endpoint(headers, params): + resp = { + "device_code": "DEV", + "user_code": "USER", + "interval": 0, + "verification_uri": "https://example.org", + "expires_in": 5, + } + + return 200, resp + + openid_provider.register_endpoint( + "device_authorization_endpoint", "POST", "/device", authorization_endpoint + ) + + def token_endpoint(headers, params): + resp = { + "access_token": secrets.token_urlsafe(), + "token_type": "bearer", + } + + return 200, resp + + openid_provider.register_endpoint( + "token_endpoint", "POST", "/token", token_endpoint + ) + + # Have the client connect. + sock, client = accept( + oauth_issuer=openid_provider.discovery_uri, + oauth_client_id="some-id", + ) + + other_wkuri = f"{openid_provider.issuer}/.well-known/oauth-authorization-server" + + if not change_between_connections: + # Immediately respond with the wrong URL. + with sock: + handle_discovery_connection(sock, other_wkuri) + + else: + # First connection; use the right URL to begin with. + with sock: + handle_discovery_connection(sock, openid_provider.discovery_uri) + + # Second connection. Reject the token and switch the URL. 
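+        # (The client learned the discovery URL during the first, successful
+        # handshake; switching it now should trigger the "document has moved"
+        # failure asserted below.)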
+ sock, _ = accept() + with sock: + with pq3.wrap(sock, debug_stream=sys.stdout) as conn: + initial = start_oauth_handshake(conn) + get_auth_value(initial) + + # Ignore the token; fail with a different discovery URL. + resp = { + "status": "invalid_token", + "openid-configuration": other_wkuri, + } + fail_oauth_handshake(conn, resp) + + expected_error = rf"server's discovery document has moved to {other_wkuri} \(previous location was {openid_provider.discovery_uri}\)" + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() diff --git a/src/test/python/conftest.py b/src/test/python/conftest.py new file mode 100644 index 0000000000000..1a73865ee4778 --- /dev/null +++ b/src/test/python/conftest.py @@ -0,0 +1,34 @@ +# +# Copyright 2023 Timescale, Inc. +# SPDX-License-Identifier: PostgreSQL +# + +import os + +import pytest + + +def pytest_addoption(parser): + """ + Adds custom command line options to py.test. We add one to signal temporary + Postgres instance creation for the server tests. + + Per pytest documentation, this must live in the top level test directory. + """ + parser.addoption( + "--temp-instance", + metavar="DIR", + help="create a temporary Postgres instance in DIR", + ) + + +@pytest.fixture(scope="session", autouse=True) +def _check_PG_TEST_EXTRA(request): + """ + Automatically skips the whole suite if PG_TEST_EXTRA doesn't contain + 'python'. pytestmark doesn't seem to work in a top-level conftest.py, so + I've made this an autoused fixture instead. + """ + extra_tests = os.getenv("PG_TEST_EXTRA", "").split() + if "python" not in extra_tests: + pytest.skip("Potentially unsafe test 'python' not enabled in PG_TEST_EXTRA") diff --git a/src/test/python/meson.build b/src/test/python/meson.build new file mode 100644 index 0000000000000..7fb0a2a7a7895 --- /dev/null +++ b/src/test/python/meson.build @@ -0,0 +1,47 @@ +# Copyright (c) 2023, PostgreSQL Global Development Group + +subdir('server') + +pytest_env = { + 'with_libcurl': oauth_flow_supported ? 'yes' : 'no', + + # Point to the default database; the tests will create their own databases as + # needed. + 'PGDATABASE': 'postgres', + + # Avoid the need for a Rust compiler on platforms without prebuilt wheels for + # pyca/cryptography. + 'CRYPTOGRAPHY_DONT_BUILD_RUST': '1', +} + +# Some modules (psycopg2) need OpenSSL at compile time; for platforms where we +# might have multiple implementations installed (macOS+brew), try to use the +# same one that libpq is using. +if ssl.found() + pytest_incdir = ssl.get_variable(pkgconfig: 'includedir', default_value: '') + if pytest_incdir != '' + pytest_env += { 'CPPFLAGS': '-I@0@'.format(pytest_incdir) } + endif + + pytest_libdir = ssl.get_variable(pkgconfig: 'libdir', default_value: '') + if pytest_libdir != '' + pytest_env += { 'LDFLAGS': '-L@0@'.format(pytest_libdir) } + endif +endif + +tests += { + 'name': 'python', + 'sd': meson.current_source_dir(), + 'bd': meson.current_build_dir(), + 'pytest': { + 'requirements': meson.current_source_dir() / 'requirements.txt', + 'tests': [ + './client', + './server', + './test_internals.py', + './test_pq3.py', + ], + 'env': pytest_env, + 'test_kwargs': {'priority': 50}, # python tests are slow, start early + }, +} diff --git a/src/test/python/pq3.py b/src/test/python/pq3.py new file mode 100644 index 0000000000000..ef809e288afd3 --- /dev/null +++ b/src/test/python/pq3.py @@ -0,0 +1,740 @@ +# +# Copyright 2021 VMware, Inc. 
+# SPDX-License-Identifier: PostgreSQL +# + +import contextlib +import getpass +import io +import os +import platform +import ssl +import sys +import textwrap + +from construct import * + +import tls + + +def protocol(major, minor): + """ + Returns the protocol version, in integer format, corresponding to the given + major and minor version numbers. + """ + return (major << 16) | minor + + +# Startup + +StringList = GreedyRange(NullTerminated(GreedyBytes)) + + +class KeyValueAdapter(Adapter): + """ + Turns a key-value store into a null-terminated list of null-terminated + strings, as presented on the wire in the startup packet. + """ + + def _encode(self, obj, context, path): + if isinstance(obj, list): + return obj + + l = [] + + for k, v in obj.items(): + if isinstance(k, str): + k = k.encode("utf-8") + l.append(k) + + if isinstance(v, str): + v = v.encode("utf-8") + l.append(v) + + l.append(b"") + return l + + def _decode(self, obj, context, path): + # TODO: turn a list back into a dict + return obj + + +KeyValues = KeyValueAdapter(StringList) + +_startup_payload = Switch( + this.proto, + { + protocol(3, 0): KeyValues, + }, + default=GreedyBytes, +) + + +def _default_protocol(this): + try: + if isinstance(this.payload, (list, dict)): + return protocol(3, 0) + except AttributeError: + pass # no payload passed during build + + return 0 + + +def _startup_payload_len(this): + """ + The payload field has a fixed size based on the length of the packet. But + if the caller hasn't supplied an explicit length at build time, we have to + build the payload to figure out how long it is, which requires us to know + the length first... This function exists solely to break the cycle. + """ + assert this._building, "_startup_payload_len() cannot be called during parsing" + + try: + payload = this.payload + except AttributeError: + return 0 # no payload + + if isinstance(payload, bytes): + # already serialized; just use the given length + return len(payload) + + try: + proto = this.proto + except AttributeError: + proto = _default_protocol(this) + + data = _startup_payload.build(payload, proto=proto) + return len(data) + + +Startup = Struct( + "len" / Default(Int32sb, lambda this: _startup_payload_len(this) + 8), + "proto" / Default(Hex(Int32sb), _default_protocol), + "payload" / FixedSized(this.len - 8, Default(_startup_payload, b"")), +) + +# Pq3 + + +# Adapted from construct.core.EnumIntegerString +class EnumNamedByte: + def __init__(self, val, name): + self._val = val + self._name = name + + def __int__(self): + return ord(self._val) + + def __str__(self): + return "(enum) %s %r" % (self._name, self._val) + + def __repr__(self): + return "EnumNamedByte(%r)" % self._val + + def __eq__(self, other): + if isinstance(other, EnumNamedByte): + other = other._val + if not isinstance(other, bytes): + return NotImplemented + + return self._val == other + + def __hash__(self): + return hash(self._val) + + +# Adapted from construct.core.Enum +class ByteEnum(Adapter): + def __init__(self, **mapping): + super(ByteEnum, self).__init__(Byte) + self.namemapping = {k: EnumNamedByte(v, k) for k, v in mapping.items()} + self.decmapping = {v: EnumNamedByte(v, k) for k, v in mapping.items()} + + def __getattr__(self, name): + if name in self.namemapping: + return self.decmapping[self.namemapping[name]] + raise AttributeError + + def _decode(self, obj, context, path): + b = bytes([obj]) + try: + return self.decmapping[b] + except KeyError: + return EnumNamedByte(b, "(unknown)") + + def _encode(self, obj, context, path): 
+ if isinstance(obj, int): + return obj + elif isinstance(obj, bytes): + return ord(obj) + return int(obj) + + +types = ByteEnum( + ErrorResponse=b"E", + ReadyForQuery=b"Z", + Query=b"Q", + EmptyQueryResponse=b"I", + AuthnRequest=b"R", + PasswordMessage=b"p", + BackendKeyData=b"K", + CommandComplete=b"C", + ParameterStatus=b"S", + DataRow=b"D", + Terminate=b"X", +) + + +authn = Enum( + Int32ub, + OK=0, + SASL=10, + SASLContinue=11, + SASLFinal=12, +) + + +_authn_body = Switch( + this.type, + { + authn.OK: Terminated, + authn.SASL: StringList, + }, + default=GreedyBytes, +) + + +def _data_len(this): + assert this._building, "_data_len() cannot be called during parsing" + + if not hasattr(this, "data") or this.data is None: + return -1 + + return len(this.data) + + +# The protocol reuses the PasswordMessage for several authentication response +# types, and there's no good way to figure out which is which without keeping +# state for the entire stream. So this is a separate Construct that can be +# explicitly parsed/built by code that knows it's needed. +SASLInitialResponse = Struct( + "name" / NullTerminated(GreedyBytes), + "len" / Default(Int32sb, lambda this: _data_len(this)), + "data" + / IfThenElse( + # Allow tests to explicitly pass an incorrect length during testing, by + # not enforcing a FixedSized during build. (The len calculation above + # defaults to the correct size.) + this._building, + Optional(GreedyBytes), + If(this.len != -1, Default(FixedSized(this.len, GreedyBytes), b"")), + ), + Terminated, # make sure the entire response is consumed +) + + +_column = FocusedSeq( + "data", + "len" / Default(Int32sb, lambda this: _data_len(this)), + "data" / If(this.len != -1, FixedSized(this.len, GreedyBytes)), +) + + +_payload_map = { + types.ErrorResponse: Struct("fields" / StringList), + types.ReadyForQuery: Struct("status" / Bytes(1)), + types.Query: Struct("query" / NullTerminated(GreedyBytes)), + types.EmptyQueryResponse: Terminated, + types.AuthnRequest: Struct("type" / authn, "body" / Default(_authn_body, b"")), + types.BackendKeyData: Struct("pid" / Int32ub, "key" / Hex(Int32ub)), + types.CommandComplete: Struct("tag" / NullTerminated(GreedyBytes)), + types.ParameterStatus: Struct( + "name" / NullTerminated(GreedyBytes), "value" / NullTerminated(GreedyBytes) + ), + types.DataRow: Struct("columns" / Default(PrefixedArray(Int16sb, _column), b"")), + types.Terminate: Terminated, +} + + +_payload = FocusedSeq( + "_payload", + "_payload" + / Switch( + this._.type, + _payload_map, + default=GreedyBytes, + ), + Terminated, # make sure every payload consumes the entire packet +) + + +def _payload_len(this): + """ + See _startup_payload_len() for an explanation. + """ + assert this._building, "_payload_len() cannot be called during parsing" + + try: + payload = this.payload + except AttributeError: + return 0 # no payload + + if isinstance(payload, bytes): + # already serialized; just use the given length + return len(payload) + + data = _payload.build(payload, type=this.type) + return len(data) + + +Pq3 = Struct( + "type" / types, + "len" / Default(Int32ub, lambda this: _payload_len(this) + 4), + "payload" + / IfThenElse( + # Allow tests to explicitly pass an incorrect length during testing, by + # not enforcing a FixedSized during build. (The len calculation above + # defaults to the correct size.) 
+ this._building, + Optional(_payload), + FixedSized(this.len - 4, Default(_payload, b"")), + ), +) + + +# Environment + + +def pghost(): + return os.environ.get("PGHOST", default="localhost") + + +def pgport(): + return int(os.environ.get("PGPORT", default=5432)) + + +def pguser(): + try: + return os.environ["PGUSER"] + except KeyError: + if platform.system() == "Windows": + # libpq defaults to GetUserName() on Windows. + return os.getlogin() + return getpass.getuser() + + +def pgdatabase(): + return os.environ.get("PGDATABASE", default="postgres") + + +# Connections + + +def _hexdump_translation_map(): + """ + For hexdumps. Translates any unprintable or non-ASCII bytes into '.'. + """ + input = bytearray() + + for i in range(128): + c = chr(i) + + if not c.isprintable(): + input += bytes([i]) + + input += bytes(range(128, 256)) + + return bytes.maketrans(input, b"." * len(input)) + + +class _DebugStream(object): + """ + Wraps a file-like object and adds hexdumps of the read and write data. Call + end_packet() on a _DebugStream to write the accumulated hexdumps to the + output stream, along with the packet that was sent. + """ + + _translation_map = _hexdump_translation_map() + + def __init__(self, stream, out=sys.stdout): + """ + Creates a new _DebugStream wrapping the given stream (which must have + been created by wrap()). All attributes not provided by the _DebugStream + are delegated to the wrapped stream. out is the text stream to which + hexdumps are written. + """ + self.raw = stream + self._out = out + self._rbuf = io.BytesIO() + self._wbuf = io.BytesIO() + + def __getattr__(self, name): + return getattr(self.raw, name) + + def __setattr__(self, name, value): + if name in ("raw", "_out", "_rbuf", "_wbuf"): + return object.__setattr__(self, name, value) + + setattr(self.raw, name, value) + + def read(self, *args, **kwargs): + buf = self.raw.read(*args, **kwargs) + + self._rbuf.write(buf) + return buf + + def write(self, b): + self._wbuf.write(b) + return self.raw.write(b) + + def recv(self, *args): + buf = self.raw.recv(*args) + + self._rbuf.write(buf) + return buf + + def _flush(self, buf, prefix): + width = 16 + hexwidth = width * 3 - 1 + + count = 0 + buf.seek(0) + + while True: + line = buf.read(16) + + if not line: + if count: + self._out.write("\n") # separate the output block with a newline + return + + self._out.write("%s %04X:\t" % (prefix, count)) + self._out.write("%*s\t" % (-hexwidth, line.hex(" "))) + self._out.write(line.translate(self._translation_map).decode("ascii")) + self._out.write("\n") + + count += 16 + + def print_debug(self, obj, *, prefix=""): + contents = "" + if obj is not None: + contents = str(obj) + + for line in contents.splitlines(): + self._out.write("%s%s\n" % (prefix, line)) + + self._out.write("\n") + + def flush_debug(self, *, prefix=""): + self._flush(self._rbuf, prefix + "<") + self._rbuf = io.BytesIO() + + self._flush(self._wbuf, prefix + ">") + self._wbuf = io.BytesIO() + + def end_packet(self, pkt, *, read=False, prefix="", indent=" "): + """ + Marks the end of a logical "packet" of data. A string representation of + pkt will be printed, and the debug buffers will be flushed with an + indent. All lines can be optionally prefixed. + + If read is True, the packet representation is written after the debug + buffers; otherwise the default of False (meaning write) causes the + packet representation to be dumped first. This is meant to capture the + logical flow of layer translation. 
+ """ + write = not read + + if write: + self.print_debug(pkt, prefix=prefix + "> ") + + self.flush_debug(prefix=prefix + indent) + + if read: + self.print_debug(pkt, prefix=prefix + "< ") + + +@contextlib.contextmanager +def wrap(socket, *, debug_stream=None): + """ + Transforms a raw socket into a connection that can be used for Construct + building and parsing. The return value is a context manager and can be used + in a with statement. + """ + # It is critical that buffering be disabled here, so that we can still + # manipulate the raw socket without desyncing the stream. + with socket.makefile("rwb", buffering=0) as sfile: + # Expose the original socket's recv() on the SocketIO object we return. + def recv(self, *args): + return socket.recv(*args) + + sfile.recv = recv.__get__(sfile) + + conn = sfile + if debug_stream: + conn = _DebugStream(conn, debug_stream) + + try: + yield conn + finally: + if debug_stream: + conn.flush_debug(prefix="? ") + + +def _send(stream, cls, obj): + debugging = hasattr(stream, "flush_debug") + out = io.BytesIO() + + # Ideally we would build directly to the passed stream, but because we need + # to reparse the generated output for the debugging case, build to an + # intermediate BytesIO and send it instead. + cls.build_stream(obj, out) + buf = out.getvalue() + + stream.write(buf) + if debugging: + pkt = cls.parse(buf) + stream.end_packet(pkt) + + stream.flush() + + +def send(stream, packet_type, payload_data=None, **payloadkw): + """ + Sends a packet on the given pq3 connection. type is the pq3.types member + that should be assigned to the packet. If payload_data is given, it will be + used as the packet payload; otherwise the key/value pairs in payloadkw will + be the payload contents. + """ + data = payloadkw + + if payload_data is not None: + if payloadkw: + raise ValueError( + "payload_data and payload keywords may not be used simultaneously" + ) + + data = payload_data + + _send(stream, Pq3, dict(type=packet_type, payload=data)) + + +def send_startup(stream, proto=None, **kwargs): + """ + Sends a startup packet on the given pq3 connection. In most cases you should + use the handshake functions instead, which will do this for you. + + By default, a protocol version 3 packet will be sent. This can be overridden + with the proto parameter. + """ + pkt = {} + + if proto is not None: + pkt["proto"] = proto + if kwargs: + pkt["payload"] = kwargs + + _send(stream, Startup, pkt) + + +def recv1(stream, *, cls=Pq3): + """ + Receives a single pq3 packet from the given stream and returns it. + """ + resp = cls.parse_stream(stream) + + debugging = hasattr(stream, "flush_debug") + if debugging: + stream.end_packet(resp, read=True) + + return resp + + +def handshake(stream, **kwargs): + """ + Performs a libpq v3 startup handshake. kwargs should contain the key/value + parameters to send to the server in the startup packet. + """ + # Send our startup parameters. + send_startup(stream, **kwargs) + + # Receive and dump packets until the server indicates it's ready for our + # first query. + while True: + resp = recv1(stream) + if resp is None: + raise RuntimeError("server closed connection during handshake") + + if resp.type == types.ReadyForQuery: + return + elif resp.type == types.ErrorResponse: + raise RuntimeError( + f"received error response from peer: {resp.payload.fields!r}" + ) + + +# TLS + + +class _TLSStream(object): + """ + A file-like object that performs TLS encryption/decryption on a wrapped + stream. 
Differs from ssl.SSLSocket in that we have full visibility and + control over the TLS layer. + """ + + def __init__(self, stream, context): + self._stream = stream + self._debugging = hasattr(stream, "flush_debug") + + self._in = ssl.MemoryBIO() + self._out = ssl.MemoryBIO() + self._ssl = context.wrap_bio(self._in, self._out) + + def handshake(self): + try: + self._pump(lambda: self._ssl.do_handshake()) + finally: + self._flush_debug(prefix="? ") + + def read(self, *args): + return self._pump(lambda: self._ssl.read(*args)) + + def write(self, *args): + return self._pump(lambda: self._ssl.write(*args)) + + def _decode(self, buf): + """ + Attempts to decode a buffer of TLS data into a packet representation + that can be printed. + + TODO: handle buffers (and record fragments) that don't align with packet + boundaries. + """ + end = len(buf) + bio = io.BytesIO(buf) + + ret = io.StringIO() + + while bio.tell() < end: + record = tls.Plaintext.parse_stream(bio) + + if ret.tell() > 0: + ret.write("\n") + ret.write("[Record] ") + ret.write(str(record)) + ret.write("\n") + + if record.type == tls.ContentType.handshake: + record_cls = tls.Handshake + else: + continue + + innerlen = len(record.fragment) + inner = io.BytesIO(record.fragment) + + while inner.tell() < innerlen: + msg = record_cls.parse_stream(inner) + + indented = "[Message] " + str(msg) + indented = textwrap.indent(indented, " ") + + ret.write("\n") + ret.write(indented) + ret.write("\n") + + return ret.getvalue() + + def flush(self): + if not self._out.pending: + self._stream.flush() + return + + buf = self._out.read() + self._stream.write(buf) + + if self._debugging: + pkt = self._decode(buf) + self._stream.end_packet(pkt, prefix=" ") + + self._stream.flush() + + def _pump(self, operation): + while True: + try: + return operation() + except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e: + want = e + self._read_write(want) + + def _recv(self, maxsize): + buf = self._stream.recv(4096) + if not buf: + self._in.write_eof() + return + + self._in.write(buf) + + if not self._debugging: + return + + pkt = self._decode(buf) + self._stream.end_packet(pkt, read=True, prefix=" ") + + def _read_write(self, want): + # XXX This needs work. So many corner cases yet to handle. For one, + # doing blocking writes in flush may lead to distributed deadlock if the + # peer is already blocking on its writes. + + if isinstance(want, ssl.SSLWantWriteError): + assert self._out.pending, "SSL backend wants write without data" + + self.flush() + + if isinstance(want, ssl.SSLWantReadError): + self._recv(4096) + + def _flush_debug(self, prefix): + if not self._debugging: + return + + self._stream.flush_debug(prefix=prefix) + + +@contextlib.contextmanager +def tls_handshake(stream, context): + """ + Performs a TLS handshake over the given stream (which must have been created + via a call to wrap()), and returns a new stream which transparently tunnels + data over the TLS connection. + + If the passed stream has debugging enabled, the returned stream will also + have debugging, using the same output IO. + """ + debugging = hasattr(stream, "flush_debug") + + # Send our startup parameters. + send_startup(stream, proto=protocol(1234, 5679)) + + # Look at the SSL response. 
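+    # The server answers an SSLRequest with a single byte: b"S" to proceed
+    # with a TLS handshake, or b"N" to refuse (both cases are checked below).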
+ resp = stream.read(1) + if debugging: + stream.flush_debug(prefix=" ") + + if resp == b"N": + raise RuntimeError("server does not support SSLRequest") + if resp != b"S": + raise RuntimeError(f"unexpected response of type {resp!r} during TLS startup") + + tls = _TLSStream(stream, context) + tls.handshake() + + if debugging: + tls = _DebugStream(tls, stream._out) + + try: + yield tls + # TODO: teardown/unwrap the connection? + finally: + if debugging: + tls.flush_debug(prefix="? ") diff --git a/src/test/python/pytest.ini b/src/test/python/pytest.ini new file mode 100644 index 0000000000000..ab7a6e7fb9660 --- /dev/null +++ b/src/test/python/pytest.ini @@ -0,0 +1,4 @@ +[pytest] + +markers = + slow: mark test as slow diff --git a/src/test/python/requirements.txt b/src/test/python/requirements.txt new file mode 100644 index 0000000000000..575e4354c8ba0 --- /dev/null +++ b/src/test/python/requirements.txt @@ -0,0 +1,11 @@ +black~=25.0 +# cryptography 35.x and later add many platform/toolchain restrictions, beware +cryptography>=3.4.8 +# TODO: figure out why 2.10.70 broke things +# (probably https://github.com/construct/construct/pull/1015) +construct==2.10.69 +isort~=5.6 +# TODO: update to psycopg[c] 3.1 +psycopg2~=2.9.7 +pytest~=7.3 +pytest-asyncio~=0.21.0 diff --git a/src/test/python/server/__init__.py b/src/test/python/server/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/test/python/server/conftest.py b/src/test/python/server/conftest.py new file mode 100644 index 0000000000000..42af80c73eedf --- /dev/null +++ b/src/test/python/server/conftest.py @@ -0,0 +1,141 @@ +# +# Portions Copyright 2021 VMware, Inc. +# Portions Copyright 2023 Timescale, Inc. +# SPDX-License-Identifier: PostgreSQL +# + +import collections +import contextlib +import os +import shutil +import socket +import subprocess +import sys + +import pytest + +import pq3 + +BLOCKING_TIMEOUT = 2 # the number of seconds to wait for blocking calls + + +def cleanup_prior_instance(datadir): + """ + Clean up an existing data directory, but make sure it actually looks like a + data directory first. (Empty folders will remain untouched, since initdb can + populate them.) + """ + required_entries = set(["base", "PG_VERSION", "postgresql.conf"]) + empty = True + + try: + with os.scandir(datadir) as entries: + for e in entries: + empty = False + required_entries.discard(e.name) + + except FileNotFoundError: + return # nothing to clean up + + if empty: + return # initdb can handle an empty datadir + + if required_entries: + pytest.fail( + f"--temp-instance directory \"{datadir}\" is not empty and doesn't look like a data directory (missing {', '.join(required_entries)})" + ) + + # Okay, seems safe enough now. + shutil.rmtree(datadir) + + +@pytest.fixture(scope="session") +def postgres_instance(pytestconfig, unused_tcp_port_factory): + """ + If --temp-instance has been passed to pytest, this fixture runs a temporary + Postgres instance on an available port. Otherwise, the fixture will attempt + to contact a running Postgres server on (PGHOST, PGPORT); dependent tests + will be skipped if the connection fails. + + Yields a (host, port) tuple for connecting to the server. + """ + PGInstance = collections.namedtuple("PGInstance", ["addr", "temporary"]) + + datadir = pytestconfig.getoption("temp_instance") + if datadir: + # We were told to create a temporary instance. Use pg_ctl to set it up + # on an unused port. 
+ cleanup_prior_instance(datadir) + subprocess.run(["pg_ctl", "-D", datadir, "init"], check=True) + + # The CI looks for *.log files to upload, so the file name here isn't + # completely arbitrary. + log = os.path.join(datadir, "postmaster.log") + port = unused_tcp_port_factory() + + subprocess.run( + [ + "pg_ctl", + "-D", + datadir, + "-l", + log, + "-o", + " ".join( + [ + f"-c port={port}", + "-c listen_addresses=localhost", + "-c log_connections=on", + "-c session_preload_libraries=oauthtest", + "-c oauth_validator_libraries=oauthtest", + ] + ), + "start", + ], + check=True, + ) + + yield ("localhost", port) + + subprocess.run(["pg_ctl", "-D", datadir, "stop"], check=True) + + else: + # Try to contact an already running server; skip the suite if we can't + # find one. + addr = (pq3.pghost(), pq3.pgport()) + + try: + with socket.create_connection(addr, timeout=BLOCKING_TIMEOUT): + pass + except ConnectionError as e: + pytest.skip(f"unable to connect to Postgres server at {addr}: {e}") + + yield addr + + +@pytest.fixture +def connect(postgres_instance): + """ + A factory fixture that, when called, returns a socket connected to a + Postgres server, wrapped in a pq3 connection. Dependent tests will be + skipped if no server is available. + """ + addr = postgres_instance + + # Set up an ExitStack to handle safe cleanup of all of the moving pieces. + with contextlib.ExitStack() as stack: + + def conn_factory(): + sock = socket.create_connection(addr, timeout=BLOCKING_TIMEOUT) + + # Have ExitStack close our socket. + stack.enter_context(sock) + + # Wrap the connection in a pq3 layer and have ExitStack clean it up + # too. + wrap_ctx = pq3.wrap(sock, debug_stream=sys.stdout) + conn = stack.enter_context(wrap_ctx) + + return conn + + yield conn_factory diff --git a/src/test/python/server/meson.build b/src/test/python/server/meson.build new file mode 100644 index 0000000000000..85534b9cc99fb --- /dev/null +++ b/src/test/python/server/meson.build @@ -0,0 +1,18 @@ +# Copyright (c) 2024, PostgreSQL Global Development Group + +oauthtest_sources = files( + 'oauthtest.c', +) + +if host_system == 'windows' + oauthtest_sources += rc_lib_gen.process(win32ver_rc, extra_args: [ + '--NAME', 'oauthtest', + '--FILEDESC', 'passthrough module to validate OAuth tests', + ]) +endif + +oauthtest = shared_module('oauthtest', + oauthtest_sources, + kwargs: pg_test_mod_args, +) +test_install_libs += oauthtest diff --git a/src/test/python/server/oauthtest.c b/src/test/python/server/oauthtest.c new file mode 100644 index 0000000000000..0166c81468952 --- /dev/null +++ b/src/test/python/server/oauthtest.c @@ -0,0 +1,119 @@ +/*------------------------------------------------------------------------- + * + * oauthtest.c + * Test module for serverside OAuth token validation callbacks + * + * Portions Copyright (c) 1996-2025, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/test/python/server/oauthtest.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include "fmgr.h" +#include "libpq/oauth.h" +#include "utils/guc.h" +#include "utils/memutils.h" + +PG_MODULE_MAGIC; + +static void test_startup(ValidatorModuleState *state); +static void test_shutdown(ValidatorModuleState *state); +static bool test_validate(const ValidatorModuleState *state, + const char *token, + const char *role, + ValidatorModuleResult *result); + +static const OAuthValidatorCallbacks callbacks = { + 
PG_OAUTH_VALIDATOR_MAGIC, + + .startup_cb = test_startup, + .shutdown_cb = test_shutdown, + .validate_cb = test_validate, +}; + +static char *expected_bearer = ""; +static bool set_authn_id = false; +static char *authn_id = ""; +static bool reflect_role = false; + +void +_PG_init(void) +{ + DefineCustomStringVariable("oauthtest.expected_bearer", + "Expected Bearer token for future connections", + NULL, + &expected_bearer, + "", + PGC_SIGHUP, + 0, + NULL, NULL, NULL); + + DefineCustomBoolVariable("oauthtest.set_authn_id", + "Whether to set an authenticated identity", + NULL, + &set_authn_id, + false, + PGC_SIGHUP, + 0, + NULL, NULL, NULL); + DefineCustomStringVariable("oauthtest.authn_id", + "Authenticated identity to use for future connections", + NULL, + &authn_id, + "", + PGC_SIGHUP, + 0, + NULL, NULL, NULL); + + DefineCustomBoolVariable("oauthtest.reflect_role", + "Ignore the bearer token; use the requested role as the authn_id", + NULL, + &reflect_role, + false, + PGC_SIGHUP, + 0, + NULL, NULL, NULL); + + MarkGUCPrefixReserved("oauthtest"); +} + +const OAuthValidatorCallbacks * +_PG_oauth_validator_module_init(void) +{ + return &callbacks; +} + +static void +test_startup(ValidatorModuleState *state) +{ +} + +static void +test_shutdown(ValidatorModuleState *state) +{ +} + +static bool +test_validate(const ValidatorModuleState *state, + const char *token, const char *role, + ValidatorModuleResult *res) +{ + if (reflect_role) + { + res->authorized = true; + res->authn_id = pstrdup(role); + } + else + { + if (*expected_bearer && strcmp(token, expected_bearer) == 0) + res->authorized = true; + if (set_authn_id) + res->authn_id = pstrdup(authn_id); + } + + return true; +} diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py new file mode 100644 index 0000000000000..9e9fca650a010 --- /dev/null +++ b/src/test/python/server/test_oauth.py @@ -0,0 +1,1080 @@ +# +# Copyright 2021 VMware, Inc. +# Portions Copyright 2023 Timescale, Inc. +# SPDX-License-Identifier: PostgreSQL +# + +import base64 +import contextlib +import json +import os +import pathlib +import platform +import secrets +import shlex +import shutil +import socket +import struct +from multiprocessing import shared_memory + +import psycopg2 +import pytest +from construct import Container +from psycopg2 import sql + +import pq3 + +from .conftest import BLOCKING_TIMEOUT + +MAX_SASL_MESSAGE_LENGTH = 65535 + +INVALID_AUTHORIZATION_ERRCODE = b"28000" +PROTOCOL_VIOLATION_ERRCODE = b"08P01" +FEATURE_NOT_SUPPORTED_ERRCODE = b"0A000" + +SHARED_MEM_NAME = "oauth-pytest" +MAX_UINT16 = 2**16 - 1 + + +@contextlib.contextmanager +def prepend_file(path, lines, *, suffix=".bak"): + """ + A context manager that prepends a file on disk with the desired lines of + text. When the context manager is exited, the file will be restored to its + original contents. + """ + # First make a backup of the original file. + bak = path + suffix + shutil.copy2(path, bak) + + try: + # Write the new lines, followed by the original file content. + with open(path, "w") as new, open(bak, "r") as orig: + new.writelines(lines) + shutil.copyfileobj(orig, new) + + # Return control to the calling code. + yield + + finally: + # Put the backup back into place. + os.replace(bak, path) + + +@pytest.fixture(scope="module") +def oauth_ctx(postgres_instance): + """ + Creates a database and user that use the oauth auth method. 
The context + object contains the dbname and user attributes as strings to be used during + connection, as well as the issuer and scope that have been set in the HBA + configuration. + + This fixture assumes that the standard PG* environment variables point to a + server running on a local machine, and that the PGUSER has rights to create + databases and roles. + """ + id = secrets.token_hex(4) + + class Context: + dbname = "oauth_test_" + id + + user = "oauth_user_" + id + punct_user = "oauth_\"'? ;&!_user_" + id # username w/ punctuation + map_user = "oauth_map_user_" + id + authz_user = "oauth_authz_user_" + id + + issuer = "https://example.com/" + id + scope = "openid " + id + + ctx = Context() + hba_lines = [ + f'host {ctx.dbname} {ctx.map_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n', + f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" delegate_ident_mapping=1\n', + f'host {ctx.dbname} all samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n', + ] + ident_lines = [r"oauth /^(.*)@example\.com$ \1"] + + if platform.system() == "Windows": + # XXX why is 'samehost' not behaving as expected on Windows? + for l in list(hba_lines): + hba_lines.append(l.replace("samehost", "::1/128")) + + host, port = postgres_instance + conn = psycopg2.connect(host=host, port=port) + conn.autocommit = True + + with contextlib.closing(conn): + c = conn.cursor() + + # Create our roles and database. + user = sql.Identifier(ctx.user) + punct_user = sql.Identifier(ctx.punct_user) + map_user = sql.Identifier(ctx.map_user) + authz_user = sql.Identifier(ctx.authz_user) + dbname = sql.Identifier(ctx.dbname) + + c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(user)) + c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(punct_user)) + c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(map_user)) + c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user)) + c.execute(sql.SQL("CREATE DATABASE {};").format(dbname)) + + # Replace pg_hba and pg_ident. + c.execute("SHOW hba_file;") + hba = c.fetchone()[0] + + c.execute("SHOW ident_file;") + ident = c.fetchone()[0] + + with prepend_file(hba, hba_lines), prepend_file(ident, ident_lines): + c.execute("SELECT pg_reload_conf();") + + # Use the new database and user. + yield ctx + + # Put things back the way they were. + c.execute("SELECT pg_reload_conf();") + + c.execute(sql.SQL("DROP DATABASE {};").format(dbname)) + c.execute(sql.SQL("DROP ROLE {};").format(authz_user)) + c.execute(sql.SQL("DROP ROLE {};").format(map_user)) + c.execute(sql.SQL("DROP ROLE {};").format(punct_user)) + c.execute(sql.SQL("DROP ROLE {};").format(user)) + + +@pytest.fixture() +def conn(oauth_ctx, connect): + """ + A convenience wrapper for connect(). The main purpose of this fixture is to + make sure oauth_ctx runs its setup code before the connection is made. + """ + return connect() + + +def bearer_token(*, size=16): + """ + Generates a Bearer token using secrets.token_urlsafe(). The generated token + size in bytes may be specified; if unset, a small 16-byte token will be + generated. 
+ """ + + if size % 4: + raise ValueError(f"requested token size {size} is not a multiple of 4") + + token = secrets.token_urlsafe(size // 4 * 3) + assert len(token) == size + + return token + + +def begin_oauth_handshake(conn, oauth_ctx, *, user=None): + if user is None: + user = oauth_ctx.authz_user + + pq3.send_startup(conn, user=user, database=oauth_ctx.dbname) + + resp = pq3.recv1(conn) + assert resp.type == pq3.types.AuthnRequest + + # The server should advertise exactly one mechanism. + assert resp.payload.type == pq3.authn.SASL + assert resp.payload.body == [b"OAUTHBEARER", b""] + + +def send_initial_response(conn, *, auth=None, bearer=None): + """ + Sends the OAUTHBEARER initial response on the connection, using the given + bearer token. Alternatively to a bearer token, the initial response's auth + field may be explicitly specified to test corner cases. + """ + if bearer is not None and auth is not None: + raise ValueError("exactly one of the auth and bearer kwargs must be set") + + if bearer is not None: + auth = b"Bearer " + bearer + + if auth is None: + raise ValueError("exactly one of the auth and bearer kwargs must be set") + + initial = pq3.SASLInitialResponse.build( + dict( + name=b"OAUTHBEARER", + data=b"n,,\x01auth=" + auth + b"\x01\x01", + ) + ) + pq3.send(conn, pq3.types.PasswordMessage, initial) + + +def expect_handshake_success(conn): + """ + Validates that the server responds with an AuthnOK message, and then drains + the connection until a ReadyForQuery message is received. + """ + resp = pq3.recv1(conn) + + assert resp.type == pq3.types.AuthnRequest + assert resp.payload.type == pq3.authn.OK + assert not resp.payload.body + + receive_until(conn, pq3.types.ReadyForQuery) + + +def expect_handshake_failure(conn, oauth_ctx): + """ + Performs the OAUTHBEARER SASL failure "handshake" and validates the server's + side of the conversation, including the final ErrorResponse. + """ + + # We expect a discovery "challenge" back from the server before the authn + # failure message. + resp = pq3.recv1(conn) + assert resp.type == pq3.types.AuthnRequest + + req = resp.payload + assert req.type == pq3.authn.SASLContinue + + body = json.loads(req.body) + assert body["status"] == "invalid_token" + assert body["scope"] == oauth_ctx.scope + + expected_config = oauth_ctx.issuer + "/.well-known/openid-configuration" + assert body["openid-configuration"] == expected_config + + # Send the dummy response to complete the failed handshake. + pq3.send(conn, pq3.types.PasswordMessage, b"\x01") + resp = pq3.recv1(conn) + + err = ExpectedError(INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed") + err.match(resp) + + +def receive_until(conn, type): + """ + receive_until pulls packets off the pq3 connection until a packet with the + desired type is found, or an error response is received. + """ + while True: + pkt = pq3.recv1(conn) + + if pkt.type == type: + return pkt + elif pkt.type == pq3.types.ErrorResponse: + raise RuntimeError( + f"received error response from peer: {pkt.payload.fields!r}" + ) + + +@pytest.fixture() +def setup_validator(postgres_instance): + """ + A per-test fixture that sets up the test validator with expected behavior. + The setting will be reverted during teardown. + """ + host, port = postgres_instance + conn = psycopg2.connect(host=host, port=port) + conn.autocommit = True + + with contextlib.closing(conn): + c = conn.cursor() + prev = dict() + + def setter(**gucs): + for guc, val in gucs.items(): + # Save the previous value. 
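+                # With sql.Identifier quoting, setter(expected_bearer="abc")
+                # runs, e.g.:
+                #   SHOW oauthtest."expected_bearer";
+                #   ALTER SYSTEM SET oauthtest."expected_bearer" TO 'abc';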
+ c.execute(sql.SQL("SHOW oauthtest.{};").format(sql.Identifier(guc))) + prev[guc] = c.fetchone()[0] + + c.execute( + sql.SQL("ALTER SYSTEM SET oauthtest.{} TO %s;").format( + sql.Identifier(guc) + ), + (val,), + ) + c.execute("SELECT pg_reload_conf();") + + yield setter + + # Restore the previous values. + for guc, val in prev.items(): + c.execute( + sql.SQL("ALTER SYSTEM SET oauthtest.{} TO %s;").format( + sql.Identifier(guc) + ), + (val,), + ) + c.execute("SELECT pg_reload_conf();") + + +@pytest.mark.parametrize("token_len", [16, 1024, 4096]) +@pytest.mark.parametrize( + "auth_prefix", + [ + b"Bearer ", + b"bearer ", + b"Bearer ", + ], +) +def test_oauth(setup_validator, connect, oauth_ctx, auth_prefix, token_len): + # Generate our bearer token with the desired length. + token = bearer_token(size=token_len) + setup_validator(expected_bearer=token) + + conn = connect() + begin_oauth_handshake(conn, oauth_ctx) + + auth = auth_prefix + token.encode("ascii") + send_initial_response(conn, auth=auth) + expect_handshake_success(conn) + + # Make sure that the server has not set an authenticated ID. + pq3.send(conn, pq3.types.Query, query=b"SELECT system_user;") + resp = receive_until(conn, pq3.types.DataRow) + + row = resp.payload + assert row.columns == [None] + + +@pytest.mark.parametrize( + "token_value", + [ + "abcdzA==", + "123456M=", + "x-._~+/x", + ], +) +def test_oauth_bearer_corner_cases(setup_validator, connect, oauth_ctx, token_value): + setup_validator(expected_bearer=token_value) + + conn = connect() + begin_oauth_handshake(conn, oauth_ctx) + + send_initial_response(conn, bearer=token_value.encode("ascii")) + + expect_handshake_success(conn) + + +@pytest.mark.parametrize( + "user,authn_id,should_succeed", + [ + pytest.param( + lambda ctx: ctx.user, + lambda ctx: ctx.user, + True, + id="validator authn: succeeds when authn_id == username", + ), + pytest.param( + lambda ctx: ctx.user, + lambda ctx: None, + False, + id="validator authn: fails when authn_id is not set", + ), + pytest.param( + lambda ctx: ctx.user, + lambda ctx: "", + False, + id="validator authn: fails when authn_id is empty", + ), + pytest.param( + lambda ctx: ctx.user, + lambda ctx: ctx.authz_user, + False, + id="validator authn: fails when authn_id != username", + ), + pytest.param( + lambda ctx: ctx.map_user, + lambda ctx: ctx.map_user + "@example.com", + True, + id="validator with map: succeeds when authn_id matches map", + ), + pytest.param( + lambda ctx: ctx.map_user, + lambda ctx: None, + False, + id="validator with map: fails when authn_id is not set", + ), + pytest.param( + lambda ctx: ctx.map_user, + lambda ctx: ctx.map_user + "@example.net", + False, + id="validator with map: fails when authn_id doesn't match map", + ), + pytest.param( + lambda ctx: ctx.authz_user, + lambda ctx: None, + True, + id="validator authz: succeeds with no authn_id", + ), + pytest.param( + lambda ctx: ctx.authz_user, + lambda ctx: "", + True, + id="validator authz: succeeds with empty authn_id", + ), + pytest.param( + lambda ctx: ctx.authz_user, + lambda ctx: "postgres", + True, + id="validator authz: succeeds with basic username", + ), + pytest.param( + lambda ctx: ctx.authz_user, + lambda ctx: "me@example.com", + True, + id="validator authz: succeeds with email address", + ), + ], +) +def test_oauth_authn_id( + setup_validator, connect, oauth_ctx, user, authn_id, should_succeed +): + token = bearer_token() + authn_id = authn_id(oauth_ctx) + + # Set up the validator appropriately. 
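+    # (authn_id may be None, meaning the validator should not report an
+    # authenticated identity at all; only set the authn GUCs when a value was
+    # actually provided.)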
+ gucs = dict(expected_bearer=token) + if authn_id is not None: + gucs["set_authn_id"] = True + gucs["authn_id"] = authn_id + setup_validator(**gucs) + + conn = connect() + username = user(oauth_ctx) + begin_oauth_handshake(conn, oauth_ctx, user=username) + send_initial_response(conn, bearer=token.encode("ascii")) + + if not should_succeed: + expect_handshake_failure(conn, oauth_ctx) + return + + expect_handshake_success(conn) + + # Check the reported authn_id. + pq3.send(conn, pq3.types.Query, query=b"SELECT system_user;") + resp = receive_until(conn, pq3.types.DataRow) + + expected = authn_id + if expected is not None: + expected = b"oauth:" + expected.encode("ascii") + + row = resp.payload + assert row.columns == [expected] + + +class ExpectedError(object): + def __init__(self, code, msg=None, detail=None): + self.code = code + self.msg = msg + self.detail = detail + + # Protect against the footgun of an accidental empty string, which will + # "match" anything. If you don't want to match message or detail, just + # don't pass them. + if self.msg == "": + raise ValueError("msg must be non-empty or None") + if self.detail == "": + raise ValueError("detail must be non-empty or None") + + def _getfield(self, resp, type): + """ + Searches an ErrorResponse for a single field of the given type (e.g. + "M", "C", "D") and returns its value. Asserts if it doesn't find exactly + one field. + """ + prefix = type.encode("ascii") + fields = [f for f in resp.payload.fields if f.startswith(prefix)] + + assert len(fields) == 1, f"did not find exactly one {type} field" + return fields[0][1:] # strip off the type byte + + def match(self, resp): + """ + Checks that the given response matches the expected code, message, and + detail (if given). The error code must match exactly. The expected + message and detail must be contained within the actual strings. + """ + assert resp.type == pq3.types.ErrorResponse + + code = self._getfield(resp, "C") + assert code == self.code + + if self.msg: + msg = self._getfield(resp, "M") + expected = self.msg.encode("utf-8") + assert expected in msg + + if self.detail: + detail = self._getfield(resp, "D") + expected = self.detail.encode("utf-8") + assert expected in detail + + +def test_oauth_rejected_bearer(conn, oauth_ctx): + begin_oauth_handshake(conn, oauth_ctx) + + # Send a bearer token that doesn't match what the validator expects. It + # should fail the connection. + send_initial_response(conn, bearer=b"xxxxxx") + + expect_handshake_failure(conn, oauth_ctx) + + +@pytest.mark.parametrize( + "bad_bearer", + [ + b"Bearer ", + b"Bearer a===b", + b"Bearer hello!", + b"Bearer trailingspace ", + b"Bearer trailingtab\t", + b"Bearer me@example.com", + b"Beare abcd", + b" Bearer leadingspace", + b'OAuth realm="Example"', + b"", + ], +) +def test_oauth_invalid_bearer(setup_validator, connect, oauth_ctx, bad_bearer): + # Tell the validator to accept any token. This ensures that the invalid + # bearer tokens are rejected before the validation step. 
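+    # (reflect_role makes the test validator authorize every connection using
+    # the requested role, so any failure observed below must come from the
+    # server's own OAUTHBEARER message parsing.)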
+ setup_validator(reflect_role=True)
+
+ conn = connect()
+ begin_oauth_handshake(conn, oauth_ctx)
+ send_initial_response(conn, auth=bad_bearer)
+
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.parametrize(
+ "resp_type,resp,err",
+ [
+ pytest.param(
+ None,
+ None,
+ None,
+ marks=pytest.mark.slow,
+ id="no response (expect timeout)",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"hello",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="bad dummy response",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"\x01\x01",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="multiple kvseps",
+ ),
+ pytest.param(
+ pq3.types.Query,
+ dict(query=b""),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="bad response message type",
+ ),
+ ],
+)
+def test_oauth_bad_response_to_error_challenge(conn, oauth_ctx, resp_type, resp, err):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send an empty auth initial response, which will force an authn failure.
+ send_initial_response(conn, auth=b"")
+
+ # We expect a discovery "challenge" back from the server before the authn
+ # failure message.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.AuthnRequest
+
+ req = pkt.payload
+ assert req.type == pq3.authn.SASLContinue
+
+ body = json.loads(req.body)
+ assert body["status"] == "invalid_token"
+
+ if resp_type is None:
+ # Do not send the dummy response. We should time out and not get a
+ # response from the server.
+ with pytest.raises(socket.timeout):
+ conn.read(1)
+
+ # Done with the test.
+ return
+
+ # Send the bad response.
+ pq3.send(conn, resp_type, resp)
+
+ # Make sure the server fails the connection correctly.
+ pkt = pq3.recv1(conn)
+ err.match(pkt)
+
+
+@pytest.mark.parametrize(
+ "type,payload,err",
+ [
+ pytest.param(
+ pq3.types.ErrorResponse,
+ dict(fields=[b""]),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="error response in initial message",
+ ),
+ pytest.param(
+ None,
+ # Sending an actual 65k packet results in ECONNRESET on Windows, and
+ # it floods the tests' connection log uselessly, so just fake the
+ # length and send a smaller number of bytes.
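+ # (The declared len exceeds MAX_SASL_MESSAGE_LENGTH, but only 512
+ # payload bytes follow it.)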
+ dict( + type=pq3.types.PasswordMessage, + len=MAX_SASL_MESSAGE_LENGTH + 1, + payload=b"x" * 512, + ), + ExpectedError( + INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed" + ), + id="overlong initial response data", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"SCRAM-SHA-256")), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, "invalid SASL authentication mechanism" + ), + id="bad SASL mechanism selection", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=2, data=b"x")), + ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "insufficient data"), + id="SASL data underflow", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=0, data=b"x")), + ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "invalid message format"), + id="SASL data overflow", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"")), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, + "malformed OAUTHBEARER message", + "message is empty", + ), + id="empty", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build( + dict(name=b"OAUTHBEARER", data=b"n,,\x01auth=\x01\x01\0") + ), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, + "malformed OAUTHBEARER message", + "length does not match input length", + ), + id="contains null byte", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"\x01")), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, + "malformed OAUTHBEARER message", + "Unexpected channel-binding flag", # XXX this is a bit strange + ), + id="initial error response", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build( + dict(name=b"OAUTHBEARER", data=b"p=tls-server-end-point,,\x01") + ), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, + "malformed OAUTHBEARER message", + "server does not support channel binding", + ), + id="uses channel binding", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"x,,\x01")), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, + "malformed OAUTHBEARER message", + "Unexpected channel-binding flag", + ), + id="invalid channel binding specifier", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y")), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, + "malformed OAUTHBEARER message", + "Comma expected", + ), + id="bad GS2 header: missing channel binding terminator", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,a")), + ExpectedError( + FEATURE_NOT_SUPPORTED_ERRCODE, + "client uses authorization identity", + ), + id="bad GS2 header: authzid in use", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,b,")), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, + "malformed OAUTHBEARER message", + "Unexpected attribute", + ), + id="bad GS2 header: extra attribute", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,")), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, + "malformed OAUTHBEARER message", + 'Unexpected attribute "0x00"', # XXX this is a bit strange + ), + id="bad GS2 header: missing authzid terminator", + ), + pytest.param( + 
pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Key-value separator expected",
+ ),
+ id="missing initial kvsep",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: empty key-value list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: other keys present",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "unterminated key/value pair",
+ ),
+ id="missing value terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: empty list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: with auth value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01\x01blah")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "additional data after the final terminator",
+ ),
+ id="additional key after terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01key\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "key without a value",
+ ),
+ id="key without value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"y,,\x01auth=Bearer 0\x01auth=Bearer 1\x01\x01",
+ )
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "contains multiple auth values",
+ ),
+ id="multiple auth values",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"y,,\x01=\x01\x01",
+ )
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "empty key name",
+ ),
+ id="empty key",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"y,,\x01my key= \x01\x01",
+ )
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+
"invalid key name", + ), + id="whitespace in key name", + ), + pytest.param( + pq3.types.PasswordMessage, + pq3.SASLInitialResponse.build( + dict( + name=b"OAUTHBEARER", + data=b"y,,\x01key=a\x05b\x01\x01", + ) + ), + ExpectedError( + PROTOCOL_VIOLATION_ERRCODE, + "malformed OAUTHBEARER message", + "invalid value", + ), + id="junk in value", + ), + ], +) +def test_oauth_bad_initial_response(conn, oauth_ctx, type, payload, err): + begin_oauth_handshake(conn, oauth_ctx) + + # The server expects a SASL response; give it something else instead. + if type is not None: + # Build a new packet of the desired type. + if not isinstance(payload, dict): + payload = dict(payload_data=payload) + pq3.send(conn, type, **payload) + else: + # The test has a custom packet to send. (The only reason to do this is + # if the packet is corrupt or otherwise unbuildable/unparsable, so we + # don't use the standard pq3.send().) + conn.write(pq3.Pq3.build(payload)) + conn.end_packet(Container(payload)) + + resp = pq3.recv1(conn) + err.match(resp) + + +def test_oauth_empty_initial_response(setup_validator, connect, oauth_ctx): + token = bearer_token() + setup_validator(expected_bearer=token) + + conn = connect() + begin_oauth_handshake(conn, oauth_ctx) + + # Send an initial response without data. + initial = pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER")) + pq3.send(conn, pq3.types.PasswordMessage, initial) + + # The server should respond with an empty challenge so we can send the data + # it wants. + pkt = pq3.recv1(conn) + + assert pkt.type == pq3.types.AuthnRequest + assert pkt.payload.type == pq3.authn.SASLContinue + assert not pkt.payload.body + + # Now send the initial data. + data = b"n,,\x01auth=Bearer " + token.encode("ascii") + b"\x01\x01" + pq3.send(conn, pq3.types.PasswordMessage, data) + + # Server should now complete the handshake. + expect_handshake_success(conn) + + +# TODO: see if there's a way to test this easily after the API switch +def xtest_oauth_no_validator(setup_validator, oauth_ctx, connect): + # Clear out our validator command, then establish a new connection. + set_validator("") + conn = connect() + + begin_oauth_handshake(conn, oauth_ctx) + send_initial_response(conn, bearer=bearer_token()) + + # The server should fail the connection. + expect_handshake_failure(conn, oauth_ctx) + + +@pytest.mark.parametrize( + "user", + [ + pytest.param( + lambda ctx: ctx.user, + id="basic username", + ), + pytest.param( + lambda ctx: ctx.punct_user, + id="'unsafe' characters are passed through correctly", + ), + ], +) +def test_oauth_validator_role(setup_validator, oauth_ctx, connect, user): + username = user(oauth_ctx) + + # Tell the validator to reflect the PGUSER as the authenticated identity. + setup_validator(reflect_role=True) + conn = connect() + + # Log in. Note that reflection ignores the bearer token. + begin_oauth_handshake(conn, oauth_ctx, user=username) + send_initial_response(conn, bearer=b"dontcare") + expect_handshake_success(conn) + + # Check the user identity. + pq3.send(conn, pq3.types.Query, query=b"SELECT system_user;") + resp = receive_until(conn, pq3.types.DataRow) + + row = resp.payload + expected = b"oauth:" + username.encode("utf-8") + assert row.columns == [expected] + + +@pytest.fixture +def odd_oauth_ctx(postgres_instance, oauth_ctx): + """ + Adds an HBA entry with messed up issuer/scope settings, to pin the server + behavior. + + TODO: these should really be rejected in the HBA rather than passed through + by the server. 
+ """ + id = secrets.token_hex(4) + + class Context: + user = oauth_ctx.user + dbname = oauth_ctx.dbname + + # Both of these embedded double-quotes are invalid; they're prohibited + # in both URLs and OAuth scope identifiers. + issuer = oauth_ctx.issuer + '/"/' + scope = oauth_ctx.scope + ' quo"ted' + + ctx = Context() + hba_issuer = ctx.issuer.replace('"', '""') + hba_scope = ctx.scope.replace('"', '""') + hba_lines = [ + f'host {ctx.dbname} {ctx.user} samehost oauth issuer="{hba_issuer}" scope="{hba_scope}"\n', + ] + + if platform.system() == "Windows": + # XXX why is 'samehost' not behaving as expected on Windows? + for l in list(hba_lines): + hba_lines.append(l.replace("samehost", "::1/128")) + + host, port = postgres_instance + conn = psycopg2.connect(host=host, port=port) + conn.autocommit = True + + with contextlib.closing(conn): + c = conn.cursor() + + # Replace pg_hba. Note that it's already been replaced once by + # oauth_ctx, so use a different backup prefix in prepend_file(). + c.execute("SHOW hba_file;") + hba = c.fetchone()[0] + + with prepend_file(hba, hba_lines, suffix=".bak2"): + c.execute("SELECT pg_reload_conf();") + + yield ctx + + # Put things back the way they were. + c.execute("SELECT pg_reload_conf();") + + +def test_odd_server_response(odd_oauth_ctx, connect): + """ + Verifies that the server is correctly escaping the JSON in its failure + response. + """ + conn = connect() + begin_oauth_handshake(conn, odd_oauth_ctx, user=odd_oauth_ctx.user) + + # Send an empty auth initial response, which will force an authn failure. + send_initial_response(conn, auth=b"") + + expect_handshake_failure(conn, odd_oauth_ctx) diff --git a/src/test/python/server/test_server.py b/src/test/python/server/test_server.py new file mode 100644 index 0000000000000..02126dba79220 --- /dev/null +++ b/src/test/python/server/test_server.py @@ -0,0 +1,21 @@ +# +# Copyright 2021 VMware, Inc. +# SPDX-License-Identifier: PostgreSQL +# + +import pq3 + + +def test_handshake(connect): + """Basic sanity check.""" + conn = connect() + + pq3.handshake(conn, user=pq3.pguser(), database=pq3.pgdatabase()) + + pq3.send(conn, pq3.types.Query, query=b"") + + resp = pq3.recv1(conn) + assert resp.type == pq3.types.EmptyQueryResponse + + resp = pq3.recv1(conn) + assert resp.type == pq3.types.ReadyForQuery diff --git a/src/test/python/test_internals.py b/src/test/python/test_internals.py new file mode 100644 index 0000000000000..dee4855fc0bee --- /dev/null +++ b/src/test/python/test_internals.py @@ -0,0 +1,138 @@ +# +# Copyright 2021 VMware, Inc. 
+# SPDX-License-Identifier: PostgreSQL +# + +import io + +from pq3 import _DebugStream + + +def test_DebugStream_read(): + under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz") + out = io.StringIO() + + stream = _DebugStream(under, out) + + res = stream.read(5) + assert res == b"abcde" + + res = stream.read(16) + assert res == b"fghijklmnopqrstu" + + stream.flush_debug() + + res = stream.read() + assert res == b"vwxyz" + + stream.flush_debug() + + expected = ( + "< 0000:\t61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70\tabcdefghijklmnop\n" + "< 0010:\t71 72 73 74 75 \tqrstu\n" + "\n" + "< 0000:\t76 77 78 79 7a \tvwxyz\n" + "\n" + ) + assert out.getvalue() == expected + + +def test_DebugStream_write(): + under = io.BytesIO() + out = io.StringIO() + + stream = _DebugStream(under, out) + + stream.write(b"\x00\x01\x02") + stream.flush() + + assert under.getvalue() == b"\x00\x01\x02" + + stream.write(b"\xc0\xc1\xc2") + stream.flush() + + assert under.getvalue() == b"\x00\x01\x02\xc0\xc1\xc2" + + stream.flush_debug() + + expected = "> 0000:\t00 01 02 c0 c1 c2 \t......\n\n" + assert out.getvalue() == expected + + +def test_DebugStream_read_write(): + under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz") + out = io.StringIO() + stream = _DebugStream(under, out) + + res = stream.read(5) + assert res == b"abcde" + + stream.write(b"xxxxx") + stream.flush() + + assert under.getvalue() == b"abcdexxxxxklmnopqrstuvwxyz" + + res = stream.read(5) + assert res == b"klmno" + + stream.write(b"xxxxx") + stream.flush() + + assert under.getvalue() == b"abcdexxxxxklmnoxxxxxuvwxyz" + + stream.flush_debug() + + expected = ( + "< 0000:\t61 62 63 64 65 6b 6c 6d 6e 6f \tabcdeklmno\n" + "\n" + "> 0000:\t78 78 78 78 78 78 78 78 78 78 \txxxxxxxxxx\n" + "\n" + ) + assert out.getvalue() == expected + + +def test_DebugStream_end_packet(): + under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz") + out = io.StringIO() + stream = _DebugStream(under, out) + + stream.read(5) + stream.end_packet("read description", read=True, indent=" ") + + stream.write(b"xxxxx") + stream.flush() + stream.end_packet("write description", indent=" ") + + stream.read(5) + stream.write(b"xxxxx") + stream.flush() + stream.end_packet("read/write combo for read", read=True, indent=" ") + + stream.read(5) + stream.write(b"xxxxx") + stream.flush() + stream.end_packet("read/write combo for write", indent=" ") + + expected = ( + " < 0000:\t61 62 63 64 65 \tabcde\n" + "\n" + "< read description\n" + "\n" + "> write description\n" + "\n" + " > 0000:\t78 78 78 78 78 \txxxxx\n" + "\n" + " < 0000:\t6b 6c 6d 6e 6f \tklmno\n" + "\n" + " > 0000:\t78 78 78 78 78 \txxxxx\n" + "\n" + "< read/write combo for read\n" + "\n" + "> read/write combo for write\n" + "\n" + " < 0000:\t75 76 77 78 79 \tuvwxy\n" + "\n" + " > 0000:\t78 78 78 78 78 \txxxxx\n" + "\n" + ) + assert out.getvalue() == expected diff --git a/src/test/python/test_pq3.py b/src/test/python/test_pq3.py new file mode 100644 index 0000000000000..8f225f0ec912f --- /dev/null +++ b/src/test/python/test_pq3.py @@ -0,0 +1,574 @@ +# +# Copyright 2021 VMware, Inc. 
+# SPDX-License-Identifier: PostgreSQL +# + +import contextlib +import getpass +import io +import os +import platform +import struct +import sys + +import pytest +from construct import Container, PaddingError, StreamError, TerminatedError + +import pq3 + + +@pytest.mark.parametrize( + "raw,expected,extra", + [ + pytest.param( + b"\x00\x00\x00\x10\x00\x04\x00\x00abcdefgh", + Container(len=16, proto=0x40000, payload=b"abcdefgh"), + b"", + id="8-byte payload", + ), + pytest.param( + b"\x00\x00\x00\x08\x00\x04\x00\x00", + Container(len=8, proto=0x40000, payload=b""), + b"", + id="no payload", + ), + pytest.param( + b"\x00\x00\x00\x09\x00\x04\x00\x00abcde", + Container(len=9, proto=0x40000, payload=b"a"), + b"bcde", + id="1-byte payload and extra padding", + ), + pytest.param( + b"\x00\x00\x00\x0b\x00\x03\x00\x00hi\x00", + Container(len=11, proto=pq3.protocol(3, 0), payload=[b"hi"]), + b"", + id="implied parameter list when using proto version 3.0", + ), + ], +) +def test_Startup_parse(raw, expected, extra): + with io.BytesIO(raw) as stream: + actual = pq3.Startup.parse_stream(stream) + + assert actual == expected + assert stream.read() == extra + + +@pytest.mark.parametrize( + "packet,expected_bytes", + [ + pytest.param( + dict(), + b"\x00\x00\x00\x08\x00\x00\x00\x00", + id="nothing set", + ), + pytest.param( + dict(len=10, proto=0x12345678), + b"\x00\x00\x00\x0a\x12\x34\x56\x78\x00\x00", + id="len and proto set explicitly", + ), + pytest.param( + dict(proto=0x12345678), + b"\x00\x00\x00\x08\x12\x34\x56\x78", + id="implied len with no payload", + ), + pytest.param( + dict(proto=0x12345678, payload=b"abcd"), + b"\x00\x00\x00\x0c\x12\x34\x56\x78abcd", + id="implied len with payload", + ), + pytest.param( + dict(payload=[b""]), + b"\x00\x00\x00\x09\x00\x03\x00\x00\x00", + id="implied proto version 3 when sending parameters", + ), + pytest.param( + dict(payload=[b"hi", b""]), + b"\x00\x00\x00\x0c\x00\x03\x00\x00hi\x00\x00", + id="implied proto version 3 and len when sending more than one parameter", + ), + pytest.param( + dict(payload=dict(user="jsmith", database="postgres")), + b"\x00\x00\x00\x27\x00\x03\x00\x00user\x00jsmith\x00database\x00postgres\x00\x00", + id="auto-serialization of dict parameters", + ), + ], +) +def test_Startup_build(packet, expected_bytes): + actual = pq3.Startup.build(packet) + assert actual == expected_bytes + + +@pytest.mark.parametrize( + "raw,expected,extra", + [ + pytest.param( + b"*\x00\x00\x00\x08abcd", + dict(type=b"*", len=8, payload=b"abcd"), + b"", + id="4-byte payload", + ), + pytest.param( + b"*\x00\x00\x00\x04", + dict(type=b"*", len=4, payload=b""), + b"", + id="no payload", + ), + pytest.param( + b"*\x00\x00\x00\x05xabcd", + dict(type=b"*", len=5, payload=b"x"), + b"abcd", + id="1-byte payload with extra padding", + ), + pytest.param( + b"R\x00\x00\x00\x08\x00\x00\x00\x00", + dict( + type=pq3.types.AuthnRequest, + len=8, + payload=dict(type=pq3.authn.OK, body=None), + ), + b"", + id="AuthenticationOk", + ), + pytest.param( + b"R\x00\x00\x00\x12\x00\x00\x00\x0aEXTERNAL\x00\x00", + dict( + type=pq3.types.AuthnRequest, + len=18, + payload=dict(type=pq3.authn.SASL, body=[b"EXTERNAL", b""]), + ), + b"", + id="AuthenticationSASL", + ), + pytest.param( + b"R\x00\x00\x00\x0d\x00\x00\x00\x0b12345", + dict( + type=pq3.types.AuthnRequest, + len=13, + payload=dict(type=pq3.authn.SASLContinue, body=b"12345"), + ), + b"", + id="AuthenticationSASLContinue", + ), + pytest.param( + b"R\x00\x00\x00\x0d\x00\x00\x00\x0c12345", + dict( + type=pq3.types.AuthnRequest, + len=13, 
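+ # (authn code 0x0c in the raw bytes selects SASLFinal; 0x0b in the
+ # previous case is SASLContinue)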
+ payload=dict(type=pq3.authn.SASLFinal, body=b"12345"), + ), + b"", + id="AuthenticationSASLFinal", + ), + pytest.param( + b"p\x00\x00\x00\x0bhunter2", + dict( + type=pq3.types.PasswordMessage, + len=11, + payload=b"hunter2", + ), + b"", + id="PasswordMessage", + ), + pytest.param( + b"K\x00\x00\x00\x0c\x00\x00\x00\x00\x12\x34\x56\x78", + dict( + type=pq3.types.BackendKeyData, + len=12, + payload=dict(pid=0, key=0x12345678), + ), + b"", + id="BackendKeyData", + ), + pytest.param( + b"C\x00\x00\x00\x08SET\x00", + dict( + type=pq3.types.CommandComplete, + len=8, + payload=dict(tag=b"SET"), + ), + b"", + id="CommandComplete", + ), + pytest.param( + b"E\x00\x00\x00\x11Mbad!\x00Mdog!\x00\x00", + dict(type=b"E", len=17, payload=dict(fields=[b"Mbad!", b"Mdog!", b""])), + b"", + id="ErrorResponse", + ), + pytest.param( + b"S\x00\x00\x00\x08a\x00b\x00", + dict( + type=pq3.types.ParameterStatus, + len=8, + payload=dict(name=b"a", value=b"b"), + ), + b"", + id="ParameterStatus", + ), + pytest.param( + b"Z\x00\x00\x00\x05x", + dict(type=b"Z", len=5, payload=dict(status=b"x")), + b"", + id="ReadyForQuery", + ), + pytest.param( + b"Q\x00\x00\x00\x06!\x00", + dict(type=pq3.types.Query, len=6, payload=dict(query=b"!")), + b"", + id="Query", + ), + pytest.param( + b"D\x00\x00\x00\x0b\x00\x01\x00\x00\x00\x01!", + dict(type=pq3.types.DataRow, len=11, payload=dict(columns=[b"!"])), + b"", + id="DataRow", + ), + pytest.param( + b"D\x00\x00\x00\x06\x00\x00extra", + dict(type=pq3.types.DataRow, len=6, payload=dict(columns=[])), + b"extra", + id="DataRow with extra data", + ), + pytest.param( + b"I\x00\x00\x00\x04", + dict(type=pq3.types.EmptyQueryResponse, len=4, payload=None), + b"", + id="EmptyQueryResponse", + ), + pytest.param( + b"I\x00\x00\x00\x04\xff", + dict(type=b"I", len=4, payload=None), + b"\xff", + id="EmptyQueryResponse with extra bytes", + ), + pytest.param( + b"X\x00\x00\x00\x04", + dict(type=pq3.types.Terminate, len=4, payload=None), + b"", + id="Terminate", + ), + ], +) +def test_Pq3_parse(raw, expected, extra): + with io.BytesIO(raw) as stream: + actual = pq3.Pq3.parse_stream(stream) + + assert actual == expected + assert stream.read() == extra + + +@pytest.mark.parametrize( + "fields,expected", + [ + pytest.param( + dict(type=b"*", len=5), + b"*\x00\x00\x00\x05", + id="type and len set explicitly", + ), + pytest.param( + dict(type=b"*"), + b"*\x00\x00\x00\x04", + id="implied len with no payload", + ), + pytest.param( + dict(type=b"*", payload=b"1234"), + b"*\x00\x00\x00\x081234", + id="implied len with payload", + ), + pytest.param( + dict(type=b"*", len=12, payload=b"1234"), + b"*\x00\x00\x00\x0c1234", + id="overridden len (payload underflow)", + ), + pytest.param( + dict(type=b"*", len=5, payload=b"1234"), + b"*\x00\x00\x00\x051234", + id="overridden len (payload overflow)", + ), + pytest.param( + dict(type=pq3.types.AuthnRequest, payload=dict(type=pq3.authn.OK)), + b"R\x00\x00\x00\x08\x00\x00\x00\x00", + id="implied len/type for AuthenticationOK", + ), + pytest.param( + dict( + type=pq3.types.AuthnRequest, + payload=dict( + type=pq3.authn.SASL, + body=[b"SCRAM-SHA-256-PLUS", b"SCRAM-SHA-256", b""], + ), + ), + b"R\x00\x00\x00\x2a\x00\x00\x00\x0aSCRAM-SHA-256-PLUS\x00SCRAM-SHA-256\x00\x00", + id="implied len/type for AuthenticationSASL", + ), + pytest.param( + dict( + type=pq3.types.AuthnRequest, + payload=dict(type=pq3.authn.SASLContinue, body=b"12345"), + ), + b"R\x00\x00\x00\x0d\x00\x00\x00\x0b12345", + id="implied len/type for AuthenticationSASLContinue", + ), + pytest.param( + dict( + 
type=pq3.types.AuthnRequest, + payload=dict(type=pq3.authn.SASLFinal, body=b"12345"), + ), + b"R\x00\x00\x00\x0d\x00\x00\x00\x0c12345", + id="implied len/type for AuthenticationSASLFinal", + ), + pytest.param( + dict( + type=pq3.types.PasswordMessage, + payload=b"hunter2", + ), + b"p\x00\x00\x00\x0bhunter2", + id="implied len/type for PasswordMessage", + ), + pytest.param( + dict(type=pq3.types.BackendKeyData, payload=dict(pid=1, key=7)), + b"K\x00\x00\x00\x0c\x00\x00\x00\x01\x00\x00\x00\x07", + id="implied len/type for BackendKeyData", + ), + pytest.param( + dict(type=pq3.types.CommandComplete, payload=dict(tag=b"SET")), + b"C\x00\x00\x00\x08SET\x00", + id="implied len/type for CommandComplete", + ), + pytest.param( + dict(type=pq3.types.ErrorResponse, payload=dict(fields=[b"error", b""])), + b"E\x00\x00\x00\x0berror\x00\x00", + id="implied len/type for ErrorResponse", + ), + pytest.param( + dict(type=pq3.types.ParameterStatus, payload=dict(name=b"a", value=b"b")), + b"S\x00\x00\x00\x08a\x00b\x00", + id="implied len/type for ParameterStatus", + ), + pytest.param( + dict(type=pq3.types.ReadyForQuery, payload=dict(status=b"I")), + b"Z\x00\x00\x00\x05I", + id="implied len/type for ReadyForQuery", + ), + pytest.param( + dict(type=pq3.types.Query, payload=dict(query=b"SELECT 1;")), + b"Q\x00\x00\x00\x0eSELECT 1;\x00", + id="implied len/type for Query", + ), + pytest.param( + dict(type=pq3.types.DataRow, payload=dict(columns=[b"abcd"])), + b"D\x00\x00\x00\x0e\x00\x01\x00\x00\x00\x04abcd", + id="implied len/type for DataRow", + ), + pytest.param( + dict(type=pq3.types.EmptyQueryResponse), + b"I\x00\x00\x00\x04", + id="implied len for EmptyQueryResponse", + ), + pytest.param( + dict(type=pq3.types.Terminate), + b"X\x00\x00\x00\x04", + id="implied len for Terminate", + ), + ], +) +def test_Pq3_build(fields, expected): + actual = pq3.Pq3.build(fields) + assert actual == expected + + +@pytest.mark.parametrize( + "raw,expected,extra", + [ + pytest.param( + b"\x00\x00", + dict(columns=[]), + b"", + id="no columns", + ), + pytest.param( + b"\x00\x01\x00\x00\x00\x04abcd", + dict(columns=[b"abcd"]), + b"", + id="one column", + ), + pytest.param( + b"\x00\x02\x00\x00\x00\x04abcd\x00\x00\x00\x01x", + dict(columns=[b"abcd", b"x"]), + b"", + id="multiple columns", + ), + pytest.param( + b"\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01x", + dict(columns=[b"", b"x"]), + b"", + id="empty column value", + ), + pytest.param( + b"\x00\x02\xff\xff\xff\xff\xff\xff\xff\xff", + dict(columns=[None, None]), + b"", + id="null columns", + ), + ], +) +def test_DataRow_parse(raw, expected, extra): + pkt = b"D" + struct.pack("!i", len(raw) + 4) + raw + with io.BytesIO(pkt) as stream: + actual = pq3.Pq3.parse_stream(stream) + + assert actual.type == pq3.types.DataRow + assert actual.payload == expected + assert stream.read() == extra + + +@pytest.mark.parametrize( + "fields,expected", + [ + pytest.param( + dict(), + b"\x00\x00", + id="no columns", + ), + pytest.param( + dict(columns=[None, None]), + b"\x00\x02\xff\xff\xff\xff\xff\xff\xff\xff", + id="null columns", + ), + ], +) +def test_DataRow_build(fields, expected): + actual = pq3.Pq3.build(dict(type=pq3.types.DataRow, payload=fields)) + + expected = b"D" + struct.pack("!i", len(expected) + 4) + expected + assert actual == expected + + +@pytest.mark.parametrize( + "raw,expected,exception", + [ + pytest.param( + b"EXTERNAL\x00\xff\xff\xff\xff", + dict(name=b"EXTERNAL", len=-1, data=None), + None, + id="no initial response", + ), + pytest.param( + 
b"EXTERNAL\x00\x00\x00\x00\x02me", + dict(name=b"EXTERNAL", len=2, data=b"me"), + None, + id="initial response", + ), + pytest.param( + b"EXTERNAL\x00\x00\x00\x00\x02meextra", + None, + TerminatedError, + id="extra data", + ), + pytest.param( + b"EXTERNAL\x00\x00\x00\x00\xffme", + None, + StreamError, + id="underflow", + ), + ], +) +def test_SASLInitialResponse_parse(raw, expected, exception): + ctx = contextlib.nullcontext() + if exception: + ctx = pytest.raises(exception) + + with ctx: + actual = pq3.SASLInitialResponse.parse(raw) + assert actual == expected + + +@pytest.mark.parametrize( + "fields,expected", + [ + pytest.param( + dict(name=b"EXTERNAL"), + b"EXTERNAL\x00\xff\xff\xff\xff", + id="no initial response", + ), + pytest.param( + dict(name=b"EXTERNAL", data=None), + b"EXTERNAL\x00\xff\xff\xff\xff", + id="no initial response (explicit None)", + ), + pytest.param( + dict(name=b"EXTERNAL", data=b""), + b"EXTERNAL\x00\x00\x00\x00\x00", + id="empty response", + ), + pytest.param( + dict(name=b"EXTERNAL", data=b"me@example.com"), + b"EXTERNAL\x00\x00\x00\x00\x0eme@example.com", + id="initial response", + ), + pytest.param( + dict(name=b"EXTERNAL", len=2, data=b"me@example.com"), + b"EXTERNAL\x00\x00\x00\x00\x02me@example.com", + id="data overflow", + ), + pytest.param( + dict(name=b"EXTERNAL", len=14, data=b"me"), + b"EXTERNAL\x00\x00\x00\x00\x0eme", + id="data underflow", + ), + ], +) +def test_SASLInitialResponse_build(fields, expected): + actual = pq3.SASLInitialResponse.build(fields) + assert actual == expected + + +@pytest.mark.parametrize( + "version,expected_bytes", + [ + pytest.param((3, 0), b"\x00\x03\x00\x00", id="version 3"), + pytest.param((1234, 5679), b"\x04\xd2\x16\x2f", id="SSLRequest"), + ], +) +def test_protocol(version, expected_bytes): + # Make sure the integer returned by protocol is correctly serialized on the + # wire. + assert struct.pack("!i", pq3.protocol(*version)) == expected_bytes + + +@pytest.mark.parametrize( + "envvar,func,expected", + [ + ("PGHOST", pq3.pghost, "localhost"), + ("PGPORT", pq3.pgport, 5432), + ( + "PGUSER", + pq3.pguser, + os.getlogin() if platform.system() == "Windows" else getpass.getuser(), + ), + ("PGDATABASE", pq3.pgdatabase, "postgres"), + ], +) +def test_env_defaults(monkeypatch, envvar, func, expected): + monkeypatch.delenv(envvar, raising=False) + + actual = func() + assert actual == expected + + +@pytest.mark.parametrize( + "envvars,func,expected", + [ + (dict(PGHOST="otherhost"), pq3.pghost, "otherhost"), + (dict(PGPORT="6789"), pq3.pgport, 6789), + (dict(PGUSER="postgres"), pq3.pguser, "postgres"), + (dict(PGDATABASE="template1"), pq3.pgdatabase, "template1"), + ], +) +def test_env(monkeypatch, envvars, func, expected): + for k, v in envvars.items(): + monkeypatch.setenv(k, v) + + actual = func() + assert actual == expected diff --git a/src/test/python/tls.py b/src/test/python/tls.py new file mode 100644 index 0000000000000..075c02c1ca6ea --- /dev/null +++ b/src/test/python/tls.py @@ -0,0 +1,195 @@ +# +# Copyright 2021 VMware, Inc. 
+# SPDX-License-Identifier: PostgreSQL +# + +from construct import * + +# +# TLS 1.3 +# +# Most of the types below are transcribed from RFC 8446: +# +# https://tools.ietf.org/html/rfc8446 +# + + +def _Vector(size_field, element): + return Prefixed(size_field, GreedyRange(element)) + + +# Alerts + +AlertLevel = Enum( + Byte, + warning=1, + fatal=2, +) + +AlertDescription = Enum( + Byte, + close_notify=0, + unexpected_message=10, + bad_record_mac=20, + decryption_failed_RESERVED=21, + record_overflow=22, + decompression_failure=30, + handshake_failure=40, + no_certificate_RESERVED=41, + bad_certificate=42, + unsupported_certificate=43, + certificate_revoked=44, + certificate_expired=45, + certificate_unknown=46, + illegal_parameter=47, + unknown_ca=48, + access_denied=49, + decode_error=50, + decrypt_error=51, + export_restriction_RESERVED=60, + protocol_version=70, + insufficient_security=71, + internal_error=80, + user_canceled=90, + no_renegotiation=100, + unsupported_extension=110, +) + +Alert = Struct( + "level" / AlertLevel, + "description" / AlertDescription, +) + + +# Extensions + +ExtensionType = Enum( + Int16ub, + server_name=0, + max_fragment_length=1, + status_request=5, + supported_groups=10, + signature_algorithms=13, + use_srtp=14, + heartbeat=15, + application_layer_protocol_negotiation=16, + signed_certificate_timestamp=18, + client_certificate_type=19, + server_certificate_type=20, + padding=21, + pre_shared_key=41, + early_data=42, + supported_versions=43, + cookie=44, + psk_key_exchange_modes=45, + certificate_authorities=47, + oid_filters=48, + post_handshake_auth=49, + signature_algorithms_cert=50, + key_share=51, +) + +Extension = Struct( + "extension_type" / ExtensionType, + "extension_data" / Prefixed(Int16ub, GreedyBytes), +) + + +# ClientHello + + +class _CipherSuiteAdapter(Adapter): + class _hextuple(tuple): + def __repr__(self): + return f"(0x{self[0]:02X}, 0x{self[1]:02X})" + + def _encode(self, obj, context, path): + return bytes(obj) + + def _decode(self, obj, context, path): + assert len(obj) == 2 + return self._hextuple(obj) + + +ProtocolVersion = Hex(Int16ub) + +Random = Hex(Bytes(32)) + +CipherSuite = _CipherSuiteAdapter(Byte[2]) + +ClientHello = Struct( + "legacy_version" / ProtocolVersion, + "random" / Random, + "legacy_session_id" / Prefixed(Byte, Hex(GreedyBytes)), + "cipher_suites" / _Vector(Int16ub, CipherSuite), + "legacy_compression_methods" / Prefixed(Byte, GreedyBytes), + "extensions" / _Vector(Int16ub, Extension), +) + +# ServerHello + +ServerHello = Struct( + "legacy_version" / ProtocolVersion, + "random" / Random, + "legacy_session_id_echo" / Prefixed(Byte, Hex(GreedyBytes)), + "cipher_suite" / CipherSuite, + "legacy_compression_method" / Hex(Byte), + "extensions" / _Vector(Int16ub, Extension), +) + +# Handshake + +HandshakeType = Enum( + Byte, + client_hello=1, + server_hello=2, + new_session_ticket=4, + end_of_early_data=5, + encrypted_extensions=8, + certificate=11, + certificate_request=13, + certificate_verify=15, + finished=20, + key_update=24, + message_hash=254, +) + +Handshake = Struct( + "msg_type" / HandshakeType, + "length" / Int24ub, + "payload" + / Switch( + this.msg_type, + { + HandshakeType.client_hello: ClientHello, + HandshakeType.server_hello: ServerHello, + # HandshakeType.end_of_early_data: EndOfEarlyData, + # HandshakeType.encrypted_extensions: EncryptedExtensions, + # HandshakeType.certificate_request: CertificateRequest, + # HandshakeType.certificate: Certificate, + # HandshakeType.certificate_verify: CertificateVerify, 
+ # HandshakeType.finished: Finished, + # HandshakeType.new_session_ticket: NewSessionTicket, + # HandshakeType.key_update: KeyUpdate, + }, + default=FixedSized(this.length, GreedyBytes), + ), +) + +# Records + +ContentType = Enum( + Byte, + invalid=0, + change_cipher_spec=20, + alert=21, + handshake=22, + application_data=23, +) + +Plaintext = Struct( + "type" / ContentType, + "legacy_record_version" / ProtocolVersion, + "length" / Int16ub, + "fragment" / FixedSized(this.length, GreedyBytes), +) diff --git a/src/tools/make_venv b/src/tools/make_venv new file mode 100755 index 0000000000000..879b639938877 --- /dev/null +++ b/src/tools/make_venv @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +import argparse +import subprocess +import os +import platform +import sys + +parser = argparse.ArgumentParser() + +parser.add_argument('--requirements', help='path to pip requirements file', type=str) +parser.add_argument('--privatedir', help='private directory for target', type=str) +parser.add_argument('venv_path', help='desired venv location') + +args = parser.parse_args() + +# Decide whether or not to capture stdout into a log file. We only do this if +# we've been given our own private directory. +# +# FIXME Unfortunately this interferes with debugging on Cirrus, because the +# private directory isn't uploaded in the sanity check's artifacts. When we +# don't capture the log file, it gets spammed to stdout during build... Is there +# a way to push this into the meson-log somehow? For now, the capture +# implementation is commented out. +logfile = None + +if args.privatedir: + if not os.path.isdir(args.privatedir): + os.mkdir(args.privatedir) + + # FIXME see above comment + # logpath = os.path.join(args.privatedir, 'stdout.txt') + # logfile = open(logpath, 'w') + +def run(*args): + kwargs = dict(check=True) + if logfile: + kwargs.update(stdout=logfile) + + subprocess.run(args, **kwargs) + +# Create the virtualenv first. +run(sys.executable, '-m', 'venv', '--system-site-packages', args.venv_path) + +# Update pip next. This helps avoid old pip bugs; the version inside system +# Pythons tends to be pretty out of date. +bindir = 'Scripts' if platform.system() == 'Windows' else 'bin' +python = os.path.join(args.venv_path, bindir, 'python3') +run(python, '-m', 'pip', 'install', '-U', 'pip') + +# Finally, install the test's requirements. We need pytest and pytest-tap, no +# matter what the test needs. 
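+# (pytest runs the suites themselves; pytest-tap adds TAP output support so
+# that a harness can consume the results.)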
+pip = os.path.join(args.venv_path, bindir, 'pip') +run(pip, 'install', 'pytest', 'pytest-tap') +if args.requirements: + run(pip, 'install', '-r', args.requirements) diff --git a/src/tools/testwrap b/src/tools/testwrap index 02f1951ad7e94..f9939b03109c5 100755 --- a/src/tools/testwrap +++ b/src/tools/testwrap @@ -14,6 +14,7 @@ parser.add_argument('--testgroup', help='test group', type=str) parser.add_argument('--testname', help='test name', type=str) parser.add_argument('--skip', help='skip test (with reason)', type=str) parser.add_argument('--pg-test-extra', help='extra tests', type=str) +parser.add_argument('--skip-without-extra', help='skip if PG_TEST_EXTRA is missing this arg', type=str) parser.add_argument('test_command', nargs='*') args = parser.parse_args() @@ -29,6 +30,12 @@ if args.skip is not None: print('1..0 # Skipped: ' + args.skip) sys.exit(0) +if args.skip_without_extra is not None: + extras = os.environ.get("PG_TEST_EXTRA", args.pg_test_extra) + if extras is None or args.skip_without_extra not in extras.split(): + print(f'1..0 # Skipped: PG_TEST_EXTRA does not contain "{args.skip_without_extra}"') + sys.exit(0) + if os.path.exists(testdir) and os.path.isdir(testdir): shutil.rmtree(testdir) os.makedirs(testdir)