Skip to content

Commit a7814f8

Browse files
reedwmtensorflower-gardener
authored andcommitted
Fix incorrect unknown flags error message.
Before, if both unknown and known flags were in XLA_FLAGS, the error would list the unknown flags and also incorrectly some known flags as being unknown. Now, only the unknown flags are shown. The issue was tsl::Flags::Parse would set its argc and argv parameters to only have the unknown flags. The caller in parse_flags_from_env.cc would pass a pointer to the first element of a vector<char*> for the argv parameter. But then the caller would use the size of the vector<char*> when printing unknown flags, which was not mutated by tsl::Flags::Parse, instead of argc, which was mutated. PiperOrigin-RevId: 723714502
1 parent 024065e commit a7814f8

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

third_party/xla/xla/parse_flags_from_env.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,8 @@ static void DieIfEnvHasUnknownFlagsLeft(absl::string_view envvar) {
227227
SetArgvFromEnv(envvar, env_argv);
228228

229229
if (env_argv->argc != 1) {
230+
auto unknown_flags = absl::MakeSpan(env_argv->argv).first(env_argv->argc);
230231
// Skip the first argv, which is the fake argv[0].
231-
auto unknown_flags = absl::MakeSpan(env_argv->argv);
232232
unknown_flags.remove_prefix(1);
233233
LOG(QFATAL) << "Unknown flag" << (unknown_flags.size() > 1 ? "s" : "")
234234
<< " in " << envvar << ": "

third_party/xla/xla/parse_flags_from_env_test.cc

+34
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include <string>
2626
#include <vector>
2727

28+
#include <gmock/gmock.h>
2829
#include <gtest/gtest.h>
2930
#include "absl/log/check.h"
3031
#include "absl/strings/str_format.h"
@@ -210,17 +211,50 @@ TEST(ParseFlagsFromEnv, ErrorOutOnUnknownFlag) {
210211
EXPECT_NE(child_status, 0);
211212
}
212213

214+
TEST(ParseFlagsFromEnv, UknownFlagErrorMessage) {
215+
const char* env =
216+
"--unknown_flag_1=value --int_flag=3 --unknown_flag_2=value "
217+
"--float_flag=3.0";
218+
219+
if (env == nullptr) {
220+
// Might be set from previous tests.
221+
tsl::unsetenv("TF_XLA_FLAGS");
222+
} else {
223+
tsl::setenv("TF_XLA_FLAGS", env, /*overwrite=*/true);
224+
}
225+
tsl::SubProcess child;
226+
std::vector<std::string> argv;
227+
argv.push_back(binary_name);
228+
argv.push_back("--recursing");
229+
child.SetProgram(binary_name, argv);
230+
child.SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE);
231+
child.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
232+
EXPECT_TRUE(child.Start());
233+
std::string stdout_str;
234+
std::string stderr_str;
235+
236+
int child_status = child.Communicate(nullptr, &stdout_str, &stderr_str);
237+
EXPECT_NE(child_status, 0);
238+
239+
EXPECT_THAT(
240+
stderr_str,
241+
::testing::EndsWith("Unknown flags in TF_XLA_FLAGS: "
242+
"--unknown_flag_1=value --unknown_flag_2=value\n"));
243+
}
244+
213245
} // namespace xla
214246

215247
int main(int argc, char* argv[]) {
216248
// Save name of binary so that it may invoke itself.
217249
xla::binary_name = argv[0];
218250
bool recursing = false;
219251
int32_t int_flag = 1;
252+
float float_flag = 1.;
220253
const std::vector<tsl::Flag> flag_list = {
221254
tsl::Flag("recursing", &recursing,
222255
"Whether the binary is being invoked recursively."),
223256
tsl::Flag("int_flag", &int_flag, "An integer flag to test with"),
257+
tsl::Flag("float_flag", &float_flag, "A float flag to test with"),
224258
};
225259
std::string usage = tsl::Flags::Usage(argv[0], flag_list);
226260
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);

0 commit comments

Comments
 (0)