@@ -3720,6 +3720,7 @@ module ts {
3720
3720
for ( var i = 0 ; i < typeParameters . length ; i ++ ) inferences . push ( [ ] ) ;
3721
3721
return {
3722
3722
typeParameters : typeParameters ,
3723
+ inferenceCount : 0 ,
3723
3724
inferences : inferences ,
3724
3725
inferredTypes : new Array ( typeParameters . length ) ,
3725
3726
} ;
@@ -3757,6 +3758,7 @@ module ts {
3757
3758
var typeParameters = context . typeParameters ;
3758
3759
for ( var i = 0 ; i < typeParameters . length ; i ++ ) {
3759
3760
if ( target === typeParameters [ i ] ) {
3761
+ context . inferenceCount ++ ;
3760
3762
var inferences = context . inferences [ i ] ;
3761
3763
if ( ! contains ( inferences , source ) ) inferences . push ( source ) ;
3762
3764
break ;
@@ -3771,6 +3773,35 @@ module ts {
3771
3773
inferFromTypes ( sourceTypes [ i ] , targetTypes [ i ] ) ;
3772
3774
}
3773
3775
}
3776
+ else if ( target . flags & TypeFlags . Union ) {
3777
+ // Target is a union type
3778
+ var targetTypes = ( < UnionType > target ) . types ;
3779
+ var startCount = context . inferenceCount ;
3780
+ var typeParameterCount = 0 ;
3781
+ var typeParameter : TypeParameter ;
3782
+ // First infer to each type in union that isn't a type parameter
3783
+ for ( var i = 0 ; i < targetTypes . length ; i ++ ) {
3784
+ var t = targetTypes [ i ] ;
3785
+ if ( t . flags & TypeFlags . TypeParameter && contains ( context . typeParameters , t ) ) {
3786
+ typeParameter = < TypeParameter > t ;
3787
+ typeParameterCount ++ ;
3788
+ }
3789
+ else {
3790
+ inferFromTypes ( source , t ) ;
3791
+ }
3792
+ }
3793
+ // If no inferences were produced above and union contains a single naked type parameter, infer to that type parameter
3794
+ if ( context . inferenceCount === startCount && typeParameterCount === 1 ) {
3795
+ inferFromTypes ( source , typeParameter ) ;
3796
+ }
3797
+ }
3798
+ else if ( source . flags & TypeFlags . Union ) {
3799
+ // Source is a union type, infer from each consituent type
3800
+ var sourceTypes = ( < UnionType > source ) . types ;
3801
+ for ( var i = 0 ; i < sourceTypes . length ; i ++ ) {
3802
+ inferFromTypes ( sourceTypes [ i ] , target ) ;
3803
+ }
3804
+ }
3774
3805
else if ( source . flags & TypeFlags . ObjectType && ( target . flags & ( TypeFlags . Reference | TypeFlags . Tuple ) ||
3775
3806
( target . flags & TypeFlags . Anonymous ) && target . symbol && target . symbol . flags & ( SymbolFlags . Method | SymbolFlags . TypeLiteral ) ) ) {
3776
3807
// If source is an object type, and target is a type reference, a tuple type, the type of a method, or a type literal, infer from members
@@ -5169,7 +5200,9 @@ module ts {
5169
5200
5170
5201
// Try to return the best common type if we have any return expressions.
5171
5202
if ( types . length > 0 ) {
5172
- var commonType = getCommonSupertype ( types ) ;
5203
+ // When return statements are contextually typed we allow the return type to be a union type. Otherwise we require the
5204
+ // return expressions to have a best common supertype.
5205
+ var commonType = getContextualSignature ( func ) ? getUnionType ( types ) : getCommonSupertype ( types ) ;
5173
5206
if ( ! commonType ) {
5174
5207
error ( func , Diagnostics . No_best_common_type_exists_among_return_expressions ) ;
5175
5208
0 commit comments