Skip to content

Commit e7a32c3

Browse files
committed
allow column vectors in visualize_coefficients
1 parent 0d6ee19 commit e7a32c3

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

mglearn/tools.py

+6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ def visualize_coefficients(coefficients, feature_names, n_top_features=25):
2323
positive) and smallest (most negative) n_top_features coefficients,
2424
for a total of 2 * n_top_features coefficients.
2525
"""
26+
if coefficients.ndim > 1 and coefficients.shape[1] > 1:
27+
# this is not a row or column vector
28+
raise ValueError("coeffients must be 1d array or column vector, got"
29+
" shape {}".format(coefficients.shape))
30+
coefficients = coefficients.ravel()
31+
2632
if len(coefficients) != len(feature_names):
2733
raise ValueError("Number of coefficients {} doesn't match number of"
2834
"feature names {}.".format(len(coefficients),

0 commit comments

Comments
 (0)