Skip to content

Commit 2b7ed60

Browse files
fix tf version
1 parent dafa300 commit 2b7ed60

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

requirements/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
numpy
22
scipy
3-
tensorflow
3+
tensorflow<2.16 # tf 2.16 with integration of keras 3 seems a disaster...
44
tensornetwork-ng
55
graphviz
66
jax

tensorcircuit/applications/ai/ensemble.py

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Any, List, Optional
6+
from copy import deepcopy
67
import tensorflow as tf
78
import numpy as np
89

@@ -53,10 +54,12 @@ def train(self, **kwargs: kwargus) -> None:
5354

5455
def compile(self, **kwargs: kwargus) -> None:
5556
self.permit_train = True
57+
5658
for i in range(self.count):
5759
if not self.model_trained[i]:
5860
dict_kwargs = kwargs.copy()
5961
# TODO(@refraction-ray): still not compatible with new optimizer
62+
# https://github.com/tensorflow/tensorflow/issues/58973
6063
self.models[i].compile(**dict_kwargs)
6164

6265
def __get_confidence(self, model_index: int, input: NDArray) -> NDArray:

tests/test_ensemble.py

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import tensorflow as tf
44
import numpy as np
5+
import pytest
56

67
thisfile = os.path.abspath(__file__)
78
modulepath = os.path.dirname(os.path.dirname(thisfile))
@@ -11,6 +12,9 @@
1112
from tensorcircuit.applications.ai.ensemble import bagging
1213

1314

15+
@pytest.mark.xfail(
16+
int(tf.__version__.split(".")[1]) >= 16, reason="legacy optimizer fails tf>=2.16"
17+
)
1418
def test_ensemble_bagging():
1519
data_amount = 100 # Amount of data to be used
1620
linear_dimension = 4 # linear demension of the data

0 commit comments

Comments
 (0)