Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit aad3149

Browse files
authored
Pull latest C++ x10 changes (#786)
1 parent 5235938 commit aad3149

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+676
-242
lines changed

Diff for: Sources/x10/xla_client/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ cc_library(
3030
"tf_logging.cc",
3131
"thread_pool.cc",
3232
"triggered_task.cc",
33+
"util.cc",
3334
"xla_util.cc",
3435
"xrt_computation_client.cc",
3536
"xrt_local_service.cc",
@@ -91,6 +92,7 @@ cc_library(
9192
"//tensorflow/core/kernels:data_flow",
9293
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
9394
"@com_google_absl//absl/memory",
95+
"@com_google_absl//absl/numeric:int128",
9496
"@com_google_absl//absl/strings",
9597
"@com_google_absl//absl/types:optional",
9698
"@com_google_absl//absl/types:span",

Diff for: Sources/x10/xla_client/types.h

+3
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@
2020
#include <cmath>
2121
#include <vector>
2222

23+
#include "absl/numeric/int128.h"
2324
#include "absl/types/optional.h"
2425
#include "tensorflow/compiler/xla/types.h"
2526

2627
namespace xla {
2728

29+
using hash_t = absl::uint128;
30+
2831
struct Percentile {
2932
enum class UnitOfMeaure {
3033
kNumber,

Diff for: Sources/x10/xla_client/util.cc

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#include "tensorflow/compiler/xla/xla_client/util.h"
2+
3+
#include <sstream>
4+
5+
namespace xla {
6+
namespace util {
7+
namespace {
8+
9+
hash_t LoadHash(const uint8** data, const uint8* top) {
10+
std::ptrdiff_t size = top - (*data);
11+
if (size >= sizeof(hash_t)) {
12+
hash_t v;
13+
std::memcpy(&v, *data, sizeof(v));
14+
*data += sizeof(hash_t);
15+
return v;
16+
}
17+
18+
union {
19+
hash_t h;
20+
uint8 b[sizeof(hash_t)];
21+
} uval;
22+
uval.h = 0;
23+
std::memcpy(uval.b, *data, size);
24+
*data += size;
25+
return uval.h;
26+
}
27+
28+
} // namespace
29+
30+
hash_t HashBlock(const void* data, size_t n, const hash_t& seed) {
31+
const hash_t m = 0xc6a4a7935bd1e995;
32+
const int r = 47;
33+
34+
const uint8* u8_data = reinterpret_cast<const uint8*>(data);
35+
const uint8* top = u8_data + n;
36+
hash_t h = seed ^ (n * m);
37+
while (u8_data < top) {
38+
hash_t k = LoadHash(&u8_data, top);
39+
k *= m;
40+
k ^= k >> r;
41+
k *= m;
42+
43+
h ^= k;
44+
h *= m;
45+
}
46+
h ^= h >> r;
47+
h *= m;
48+
h ^= h >> r;
49+
return h;
50+
}
51+
52+
hash_t DataHash(const void* data, size_t size) {
53+
return HashBlock(data, size, 0xc2b2ae3d27d4eb4f);
54+
}
55+
56+
size_t StdDataHash(const void* data, size_t size) {
57+
return HashReduce(DataHash(data, size));
58+
}
59+
60+
size_t StdHashCombine(uintmax_t a, uintmax_t b) {
61+
return a ^
62+
(b * 0x27d4eb2f165667c5 + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2));
63+
}
64+
65+
hash_t HashCombine(const hash_t& a, const hash_t& b) {
66+
static const hash_t kb = absl::MakeUint128(101, 0x27d4eb2f165667c5);
67+
return a ^ (b * kb + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2));
68+
}
69+
70+
size_t HashReduce(const hash_t& a) {
71+
return StdHashCombine(absl::Uint128Low64(a), absl::Uint128High64(a));
72+
}
73+
74+
std::string HexHash(const hash_t& a) {
75+
std::stringstream ss;
76+
ss << std::hex << a;
77+
return ss.str();
78+
}
79+
80+
} // namespace util
81+
} // namespace xla

Diff for: Sources/x10/xla_client/util.h

+36-21
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,47 @@
1818
#define X10_XLA_CLIENT_UTIL_H_
1919

2020
#include <algorithm>
21+
#include <cstdint>
2122
#include <cstring>
2223
#include <exception>
2324
#include <functional>
2425
#include <memory>
2526
#include <numeric>
2627
#include <set>
28+
#include <string>
2729
#include <type_traits>
2830
#include <vector>
2931

3032
#include "absl/types/optional.h"
3133
#include "absl/types/span.h"
3234
#include "tensorflow/compiler/xla/status.h"
35+
#include "tensorflow/compiler/xla/xla_client/types.h"
3336
#include "tensorflow/core/lib/core/errors.h"
3437
#include "tensorflow/core/lib/hash/hash.h"
3538

3639
namespace xla {
3740
namespace util {
3841

42+
hash_t HashBlock(const void* data, size_t n, const hash_t& seed);
43+
44+
hash_t DataHash(const void* data, size_t size);
45+
46+
size_t StdDataHash(const void* data, size_t size);
47+
48+
size_t StdHashCombine(uintmax_t a, uintmax_t b);
49+
50+
hash_t HashCombine(const hash_t& a, const hash_t& b);
51+
52+
size_t HashReduce(const hash_t& a);
53+
54+
std::string HexHash(const hash_t& a);
55+
56+
struct HashReducer {
57+
size_t operator()(const xla::hash_t& value) const {
58+
return HashReduce(value);
59+
}
60+
};
61+
3962
template <typename F>
4063
Status CheckedCall(const F& fn) {
4164
fn();
@@ -259,65 +282,57 @@ T Multiply(const S& input) {
259282
std::multiplies<T>());
260283
}
261284

262-
static inline size_t DataHash(const void* data, size_t size) {
263-
return tensorflow::Hash64(reinterpret_cast<const char*>(data), size,
264-
0xc2b2ae3d27d4eb4f);
265-
}
266-
267-
static inline size_t StringHash(const char* data) {
285+
static inline hash_t StringHash(const char* data) {
268286
return DataHash(data, std::strlen(data));
269287
}
270288

271-
static inline size_t HashCombine(size_t a, size_t b) {
272-
return a ^
273-
(b * 0x27d4eb2f165667c5 + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2));
274-
}
275-
276289
template <typename T, typename std::enable_if<
277290
std::is_arithmetic<T>::value>::type* = nullptr>
278-
size_t Hash(const T& value) {
291+
hash_t Hash(const T& value) {
279292
return DataHash(&value, sizeof(value));
280293
}
281294

282-
static inline size_t Hash(const std::string& value) {
295+
static inline hash_t Hash(const std::string& value) {
283296
return DataHash(value.data(), value.size());
284297
}
285298

286299
// Forward declare to allow hashes of vectors of vectors to work.
287300
template <typename T>
288-
size_t ContainerHash(const T& values);
301+
hash_t ContainerHash(const T& values);
289302

290303
template <typename T>
291-
size_t Hash(absl::Span<const T> values) {
304+
hash_t Hash(absl::Span<const T> values) {
292305
return ContainerHash(values);
293306
}
294307

295308
template <typename T>
296-
size_t Hash(const std::vector<T>& values) {
309+
hash_t Hash(const std::vector<T>& values) {
297310
return ContainerHash(values);
298311
}
299312

300313
template <typename T>
301-
size_t Hash(const std::set<T>& values) {
314+
hash_t Hash(const std::set<T>& values) {
302315
return ContainerHash(values);
303316
}
304317

318+
static inline hash_t Hash(const hash_t& value) { return value; }
319+
305320
template <typename T>
306-
size_t ContainerHash(const T& values) {
307-
size_t h = 0x85ebca77c2b2ae63;
321+
hash_t ContainerHash(const T& values) {
322+
hash_t h = 0x85ebca77c2b2ae63;
308323
for (auto& value : values) {
309324
h = HashCombine(h, Hash(value));
310325
}
311326
return h;
312327
}
313328

314329
template <typename T = void>
315-
size_t MHash() {
330+
hash_t MHash() {
316331
return 0x165667b19e3779f9;
317332
}
318333

319334
template <typename T, typename... Targs>
320-
size_t MHash(T value, Targs... Fargs) {
335+
hash_t MHash(T value, Targs... Fargs) {
321336
return HashCombine(Hash(value), MHash(Fargs...));
322337
}
323338

Diff for: Sources/x10/xla_client/xla_util.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace xla {
3232
namespace util {
3333
namespace {
3434

35-
size_t SingleShapeHash(const Shape& shape, size_t seed) {
35+
hash_t SingleShapeHash(const Shape& shape, hash_t seed) {
3636
for (auto dim : shape.layout().minor_to_major()) {
3737
seed = HashCombine(seed, dim);
3838
}
@@ -98,8 +98,8 @@ void CheckComputationStatus(
9898
}
9999
}
100100

101-
size_t ShapeHash(const Shape& shape) {
102-
size_t hash = 0xa5d2d6916;
101+
hash_t ShapeHash(const Shape& shape) {
102+
hash_t hash = 0xa5d2d6916;
103103
ShapeUtil::ForEachSubshape(shape,
104104
[&](const Shape& subshape, const ShapeIndex&) {
105105
hash = SingleShapeHash(subshape, hash);

Diff for: Sources/x10/xla_client/xla_util.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "tensorflow/compiler/xla/service/hlo.pb.h"
2525
#include "tensorflow/compiler/xla/service/hlo_module.h"
2626
#include "tensorflow/compiler/xla/status_macros.h"
27+
#include "tensorflow/compiler/xla/xla_client/types.h"
2728

2829
namespace xla {
2930
namespace util {
@@ -46,7 +47,7 @@ void CheckComputationStatus(
4647
const Status& status, absl::Span<const XlaComputation* const> computations,
4748
absl::Span<const Shape* const> output_shapes);
4849

49-
size_t ShapeHash(const Shape& shape);
50+
hash_t ShapeHash(const Shape& shape);
5051

5152
} // namespace util
5253
} // namespace xla

Diff for: Sources/x10/xla_client/xrt_computation_client.cc

+18-14
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class TensorAllocator : public tensorflow::Allocator {
4949
struct AllocKey {
5050
struct Hash {
5151
size_t operator()(const AllocKey& hk) const {
52-
return util::HashCombine(hk.alignment, hk.num_bytes);
52+
return util::StdHashCombine(hk.alignment, hk.num_bytes);
5353
}
5454
};
5555

@@ -544,7 +544,7 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
544544
size_t base_index = partitions[i];
545545
size_t length = (i + 1 < partitions.size())
546546
? partitions[i + 1] - base_index
547-
: partitions.size() - base_index;
547+
: tensors.size() - base_index;
548548
auto partitions_results =
549549
TransferToServerInternal(tensors.subspan(base_index, length));
550550
for (size_t r = 0; r < length; ++r) {
@@ -1913,30 +1913,34 @@ const XrtSession::CachedNode& XrtComputationClient::GetSubTupleNode(
19131913
tensorflow::DataType XrtComputationClient::XlaTypeToDataType(
19141914
PrimitiveType dtype) {
19151915
switch (dtype) {
1916-
case PRED:
1916+
case PrimitiveType::PRED:
19171917
return tensorflow::DT_BOOL;
1918-
case S8:
1918+
case PrimitiveType::S8:
19191919
return tensorflow::DT_INT8;
1920-
case U8:
1920+
case PrimitiveType::U8:
19211921
return tensorflow::DT_UINT8;
1922-
case S16:
1922+
case PrimitiveType::S16:
19231923
return tensorflow::DT_INT16;
1924-
case U16:
1924+
case PrimitiveType::U16:
19251925
return tensorflow::DT_UINT16;
1926-
case S32:
1926+
case PrimitiveType::S32:
19271927
return tensorflow::DT_INT32;
1928-
case U32:
1928+
case PrimitiveType::U32:
19291929
return tensorflow::DT_UINT32;
1930-
case S64:
1930+
case PrimitiveType::S64:
19311931
return tensorflow::DT_INT64;
1932-
case U64:
1932+
case PrimitiveType::U64:
19331933
return tensorflow::DT_UINT64;
1934-
case F32:
1934+
case PrimitiveType::F32:
19351935
return tensorflow::DT_FLOAT;
1936-
case F64:
1936+
case PrimitiveType::F64:
19371937
return tensorflow::DT_DOUBLE;
1938-
case BF16:
1938+
case PrimitiveType::BF16:
19391939
return tensorflow::DT_BFLOAT16;
1940+
case PrimitiveType::C64:
1941+
return tensorflow::DT_COMPLEX64;
1942+
case PrimitiveType::C128:
1943+
return tensorflow::DT_COMPLEX128;
19401944
default:
19411945
break;
19421946
}

Diff for: Sources/x10/xla_client/xrt_computation_client.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,9 @@ class XrtComputationClient : public ComputationClient {
208208
struct Hash {
209209
size_t operator()(const CompilationCacheKey& entry) const {
210210
util::PartialHasher<std::string, 4096> hasher;
211-
return tensorflow::Hash64(entry.domain.data(), entry.domain.size(),
212-
hasher(entry.serialized_computation));
211+
hash_t h = util::DataHash(entry.domain.data(), entry.domain.size());
212+
return util::HashReduce(
213+
util::HashCombine(h, hasher(entry.serialized_computation)));
213214
}
214215
};
215216

Diff for: Sources/x10/xla_tensor/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ tf_cc_shared_object(
6969
name = "libx10.so",
7070
linkopts = [
7171
"-z defs",
72+
"-s",
7273
"-Wl,--version-script,$(location :tf_version_script.lds)",
7374
],
7475
visibility = ["//visibility:public"],

0 commit comments

Comments
 (0)